mirror of https://github.com/hpcaitech/ColossalAI

[npu] change device to accelerator api (#5239)

* update accelerator
* fix timer
* fix amp
* update
* fix
* update bug
* add error raise
* fix autocast
* fix set device
* remove doc accelerator
* update doc
* update doc
* update doc
* use nullcontext
* update cpu
* update null context
* change time limit for example
* update
* update
* update
* update
* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>

pull/5242/head
parent dd2c28a323
commit d202cc28c0
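The change this commit applies is mechanical but repeated across many files: call sites that used the CUDA-oriented helper colossalai.utils.get_current_device (or colossalai.utils.device) now go through the backend-neutral accelerator API, so the same code path works on CUDA GPUs, Ascend NPUs, and CPU. A rough before/after sketch of the pattern, using only names that appear in the diff below (the tensor itself is just an illustration):

import torch

# before: device helper tied to the CUDA code path
from colossalai.utils import get_current_device

buf = torch.zeros(1, device=get_current_device())

# after: ask the active accelerator (cuda / npu / cpu) for its current device
from colossalai.accelerator import get_accelerator

buf = torch.zeros(1, device=get_accelerator().get_current_device())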
@@ -47,7 +47,7 @@ jobs:
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/
timeout-minutes: 10
timeout-minutes: 15
steps:
- name: 📚 Checkout
uses: actions/checkout@v3

@@ -79,7 +79,7 @@ jobs:
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/
timeout-minutes: 10
timeout-minutes: 15
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-run-example-${{ matrix.directory }}
cancel-in-progress: true

@@ -35,7 +35,7 @@ jobs:
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}
container:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
timeout-minutes: 10
timeout-minutes: 15
steps:
- name: 📚 Checkout
uses: actions/checkout@v3
@@ -10,7 +10,7 @@ from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase

from colossalai.utils import get_current_device
from colossalai.accelerator import get_accelerator

from .base import OnPolicyTrainer
from .callbacks import Callback

@@ -105,7 +105,7 @@ class PPOTrainer(OnPolicyTrainer):
self.critic_optim = critic_optim

self.offload_inference_models = offload_inference_models
self.device = get_current_device()
self.device = get_accelerator().get_current_device()

def _before_fit(
self,
@@ -6,7 +6,6 @@ import torch.nn as nn
import colossalai
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.utils import get_current_device
from colossalai.zero.gemini.gemini_ddp import GeminiDDP

from .ddp import DDPStrategy

@@ -158,9 +157,19 @@ class GeminiStrategy(DDPStrategy):

warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.")

# colossalai has changed api for get_current_device in 0.3.4 version or newer
try:
from colossalai.accelerator import get_accelerator

chunk_init_device = get_accelerator().get_current_device()
except:
from colossalai.utils import get_current_device

chunk_init_device = get_current_device()

# NOTE: dist should be initialized before calling get_current_device()
plugin_initializer = lambda: GeminiPlugin(
chunk_init_device=get_current_device(),
chunk_init_device=chunk_init_device,
placement_policy=placement_policy,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,
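The try/except above keeps the strategy working on both older and newer Colossal-AI releases, since get_current_device moved in 0.3.4. As a rough, self-contained sketch (not part of the diff; the helper name _resolve_current_device is made up for illustration), the same fallback can be factored into one function:

def _resolve_current_device():
    # Prefer the accelerator API available in colossalai 0.3.4 or newer.
    try:
        from colossalai.accelerator import get_accelerator

        return get_accelerator().get_current_device()
    except ImportError:
        # Older releases only expose the utils-level helper.
        from colossalai.utils import get_current_device

        return get_current_device()


chunk_init_device = _resolve_current_device()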
@ -4,41 +4,34 @@
|
|||
Continual Pre-training of LLaMA-2 developed by Colossal-AI Team
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossal_llama2.dataset.loader import (
|
||||
DataCollatorForSupervisedDataset,
|
||||
StatefulDistributedSampler,
|
||||
load_tokenized_dataset,
|
||||
setup_distributed_dataloader,
|
||||
)
|
||||
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
|
||||
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
|
||||
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import (
|
||||
GeminiPlugin,
|
||||
LowLevelZeroPlugin,
|
||||
HybridParallelPlugin,
|
||||
)
|
||||
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from colossal_llama2.dataset.loader import (
|
||||
load_tokenized_dataset,
|
||||
setup_distributed_dataloader,
|
||||
DataCollatorForSupervisedDataset,
|
||||
StatefulDistributedSampler,
|
||||
)
|
||||
|
||||
from colossal_llama2.utils.flash_attention_patch import replace_with_flash_attention
|
||||
from colossal_llama2.utils.ckpt_io import load_checkpoint, save_checkpoint
|
||||
from colossal_llama2.utils.froze import freeze_non_embeds_parameters
|
||||
|
||||
|
||||
def get_model_numel(model: torch.nn.Module) -> int:
|
||||
|
@ -215,9 +208,18 @@ def main() -> None:
|
|||
# ======================================================
|
||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||
# ======================================================
|
||||
init_ctx = (
|
||||
LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
)
|
||||
|
||||
# colossalai has changed api for get_current_device in 0.3.4 version or newer
|
||||
try:
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
current_device = get_accelerator().get_current_device()
|
||||
except:
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
current_device = get_current_device()
|
||||
|
||||
init_ctx = LazyInitContext(default_device=current_device) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
|
||||
with init_ctx:
|
||||
model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
|
||||
# Freeze part of parameters.
|
||||
|
@ -320,7 +322,7 @@ def main() -> None:
|
|||
initial=start_step,
|
||||
) as pbar:
|
||||
for step, batch in pbar:
|
||||
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
||||
batch = {k: v.to(current_device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
||||
|
||||
batch_output = model(**batch)
|
||||
|
||||
|
@ -372,9 +374,7 @@ def main() -> None:
|
|||
# Final save.
|
||||
coordinator.print_on_master("Start saving final model checkpoint")
|
||||
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
||||
coordinator.print_on_master(
|
||||
f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}"
|
||||
)
|
||||
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
|
||||
|
||||
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
|
||||
|
||||
|
|
|
@@ -1,5 +1,6 @@
from .api import auto_set_accelerator, get_accelerator, set_accelerator
from .base_accelerator import BaseAccelerator
from .cpu_accelerator import CpuAccelerator
from .cuda_accelerator import CudaAccelerator
from .npu_accelerator import NpuAccelerator

@@ -10,4 +11,5 @@ __all__ = [
"BaseAccelerator",
"CudaAccelerator",
"NpuAccelerator",
"CpuAccelerator",
]
@@ -3,6 +3,7 @@ from collections import OrderedDict
from typing import Union

from .base_accelerator import BaseAccelerator
from .cpu_accelerator import CpuAccelerator
from .cuda_accelerator import CudaAccelerator
from .npu_accelerator import NpuAccelerator

@@ -15,7 +16,7 @@ _ACCELERATOR = None
# we use ordered dictionary here to associate the
# order with device check priority
# i.e. auto_set_accelerator will check cuda first
_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator)
_ACCELERATOR_MAPPING = OrderedDict(cuda=CudaAccelerator, npu=NpuAccelerator, cpu=CpuAccelerator)


def set_accelerator(accelerator: Union[str, BaseAccelerator]) -> None:

@@ -43,19 +44,17 @@ def auto_set_accelerator() -> None:
"""
global _ACCELERATOR

for _, accelerator_cls in _ACCELERATOR_MAPPING.items():
for accelerator_name, accelerator_cls in _ACCELERATOR_MAPPING.items():
try:
accelerator = accelerator_cls()
if accelerator.is_available():
if accelerator_name == "cpu" or accelerator.is_available():
_ACCELERATOR = accelerator
break
except:
pass

if _ACCELERATOR is None:
raise RuntimeError(
f"No accelerator is available. Please check your environment. The list of accelerators we support is {list(_ACCELERATOR_MAPPING.keys())}"
)
raise RuntimeError("No accelerator is available.")


def get_accelerator() -> BaseAccelerator:
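A minimal usage sketch of the registry above, using only names visible in this diff (auto_set_accelerator, get_accelerator, set_accelerator); the tensor is illustrative:

import torch

from colossalai.accelerator import auto_set_accelerator, get_accelerator, set_accelerator

# Pick the best available backend following the mapping order: cuda, then npu, then cpu.
auto_set_accelerator()
acc = get_accelerator()

x = torch.zeros(1, device=acc.get_current_device())  # e.g. cuda:0, npu:0 or cpu

# A backend can also be forced explicitly; the signature above accepts a string name.
set_accelerator("cpu")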
@ -1,6 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -8,6 +9,8 @@ __all__ = ["BaseAccelerator"]
|
|||
|
||||
|
||||
class BaseAccelerator(ABC):
|
||||
support_set_device: bool = True
|
||||
|
||||
def __init__(self, name: str, communication_backend: str, is_synchronous: bool) -> None:
|
||||
self._name = name
|
||||
self._communication_backend = communication_backend
|
||||
|
@ -45,6 +48,12 @@ class BaseAccelerator(ABC):
|
|||
# =======================
|
||||
# device APIs
|
||||
# =======================
|
||||
@abstractmethod
|
||||
def get_current_device(self) -> torch.device:
|
||||
"""
|
||||
Return the current device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def current_device(self) -> int:
|
||||
"""
|
||||
|
@ -52,7 +61,7 @@ class BaseAccelerator(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_device(self, device: Union[torch.device, int]) -> None:
|
||||
def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
|
||||
"""
|
||||
Bind the current process to a device.
|
||||
"""
|
||||
|
@ -79,3 +88,226 @@ class BaseAccelerator(ABC):
|
|||
"""
|
||||
Return the number of devices on the machine.
|
||||
"""
|
||||
|
||||
def set_to_device(self, models: Any) -> Any:
|
||||
"""
|
||||
Send model to device.
|
||||
|
||||
:param models: nn.module or a list of module
|
||||
"""
|
||||
if isinstance(models, list) and len(models) > 1:
|
||||
ret = []
|
||||
for model in models:
|
||||
ret.append(model.to(self.get_current_device()))
|
||||
return ret
|
||||
elif isinstance(models, list):
|
||||
return models[0].to(self.get_current_device())
|
||||
else:
|
||||
return models.to(self.get_current_device())
|
||||
|
||||
@abstractmethod
|
||||
def get_device_capability(self, device=None) -> Tuple[int, int]:
|
||||
"""
|
||||
Gets the capability of a device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_device_name(self, device=None) -> str:
|
||||
"""
|
||||
Gets the name of a device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_device_properties(self, device):
|
||||
"""
|
||||
Gets the properties of a device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def utilization(self, device=None) -> int:
|
||||
"""
|
||||
Returns the percent of time over the past sample period during which one or more kernels was executing on the device as given by nvidia-smi or npu-smi, etc.
|
||||
"""
|
||||
|
||||
# =======================
|
||||
# random number generator APIs
|
||||
# =======================
|
||||
@abstractmethod
|
||||
def get_rng_state(self, device="cuda") -> torch.Tensor:
|
||||
"""
|
||||
Returns the random number generator state of the specified device as a ByteTensor.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_rng_state_all(self) -> List[torch.Tensor]:
|
||||
"""
|
||||
Returns a list of ByteTensor representing the random number states of all devices.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None:
|
||||
"""
|
||||
Sets the random number generator state of the specified device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
|
||||
"""
|
||||
Sets the random number generator state of all devices.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def manual_seed(self, seed: int) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers for the current device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def manual_seed_all(self, seed: int) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers on all devices.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def seed(self) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers to a random number for the current device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def seed_all(self) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers to a random number on all devices.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def initial_seed(self) -> int:
|
||||
"""
|
||||
Returns the current random seed of the current device.
|
||||
"""
|
||||
|
||||
# =======================
|
||||
# memory management APIs
|
||||
# =======================
|
||||
@abstractmethod
|
||||
def empty_cache(self) -> None:
|
||||
"""
|
||||
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other device application and visible in nvidia-smi.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def memory_stats(self, device=None) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns a dictionary of CUDA memory allocator statistics for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def memory_summary(self, device=None, abbreviated=False) -> str:
|
||||
"""
|
||||
Returns a human-readable printout of the current memory allocator statistics for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def memory_snapshot(self):
|
||||
"""
|
||||
Returns a snapshot of the CUDA memory allocator state across all devices.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def memory_allocated(self, device=None) -> int:
|
||||
"""
|
||||
Returns the current device memory occupied by tensors in bytes for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def max_memory_allocated(self, device=None) -> int:
|
||||
"""
|
||||
Returns the maximum device memory occupied by tensors in bytes for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def reset_max_memory_allocated(self, device=None) -> None:
|
||||
"""
|
||||
Resets the starting point in tracking maximum device memory occupied by tensors for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def reset_max_memory_cached(self, device=None) -> None:
|
||||
"""
|
||||
Resets the starting point in tracking maximum device memory managed by the caching allocator for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def memory_reserved(self, device=None) -> int:
|
||||
"""
|
||||
Returns the current device memory managed by the caching allocator in bytes for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def max_memory_reserved(self, device=None) -> int:
|
||||
"""
|
||||
Returns the maximum device memory managed by the caching allocator in bytes for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
|
||||
"""
|
||||
Set memory fraction for a process.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def reset_peak_memory_stats(self, device=None) -> None:
|
||||
"""
|
||||
Resets the "peak" stats tracked by the device memory allocator.
|
||||
"""
|
||||
|
||||
# =======================
|
||||
# streams and events APIs
|
||||
# =======================
|
||||
|
||||
@abstractmethod
|
||||
def Stream(self, device=None, priority=0, **kwargs):
|
||||
"""
|
||||
A device stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
|
||||
"""
|
||||
device events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def current_stream(self, device=None):
|
||||
"""
|
||||
Returns the currently selected Stream for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def default_stream(self, device=None):
|
||||
"""
|
||||
Returns the default Stream for a given device.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_stream(self, stream_):
|
||||
"""
|
||||
Sets the current stream.This is a wrapper API to set the stream.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def stream(self, stream_):
|
||||
"""
|
||||
Wrapper around the Context-manager StreamContext that selects a given stream.
|
||||
"""
|
||||
|
||||
# =======================
|
||||
# amp APIs
|
||||
# =======================
|
||||
@abstractmethod
|
||||
def autocast(
|
||||
self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
|
||||
) -> Callable:
|
||||
"""
|
||||
Return autocast function
|
||||
"""
|
||||
|
|
|
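The abstract interface above is what the rest of the codebase is expected to program against. A hedged sketch of backend-agnostic user code written only against methods listed in this hunk (get_current_device, manual_seed, autocast, empty_cache); it assumes a CUDA or NPU backend, since the CPU accelerator shown below raises RuntimeError for some of these calls:

import torch
import torch.nn as nn

from colossalai.accelerator import get_accelerator

acc = get_accelerator()                      # whichever backend is registered (cuda / npu)
acc.manual_seed(42)                          # seed the device RNG through the common interface

device = acc.get_current_device()
model = nn.Linear(8, 8).to(device)
inputs = torch.randn(4, 8, device=device)

with acc.autocast():                         # backend-appropriate mixed-precision context
    loss = model(inputs).sum()

acc.empty_cache()                            # release cached allocator memory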
@ -0,0 +1,277 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import resource
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from .base_accelerator import BaseAccelerator
|
||||
|
||||
__all__ = ["CpuAccelerator"]
|
||||
|
||||
|
||||
class CpuAccelerator(BaseAccelerator):
|
||||
support_set_device: bool = False
|
||||
"""
|
||||
Accelerator class for cpu.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(name="cpu", communication_backend="gloo", is_synchronous=False)
|
||||
|
||||
# =======================
|
||||
# device APIs
|
||||
# =======================
|
||||
def get_current_device(self) -> torch.device:
|
||||
"""
|
||||
Return the current device.
|
||||
"""
|
||||
return torch.device("cpu")
|
||||
|
||||
def current_device(self) -> int:
|
||||
"""
|
||||
Return the current device index.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
|
||||
"""
|
||||
Bind the current process to a device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def get_device_name(self, device: Union[torch.device, int]) -> str:
|
||||
"""
|
||||
Return the name of the device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def synchronize(self, device: Union[torch.device, int] = None):
|
||||
"""
|
||||
Synchronize the current process.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def is_available(self):
|
||||
"""
|
||||
Check if the accelerator is available.
|
||||
"""
|
||||
return True
|
||||
|
||||
def device_count(self):
|
||||
"""
|
||||
Return the number of devices on the machine.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def get_device_capability(self, device=None) -> Tuple[int, int]:
|
||||
"""
|
||||
Gets the cuda capability of a device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def get_device_name(self, device=None) -> str:
|
||||
"""
|
||||
Gets the name of a device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def get_device_properties(self, device):
|
||||
"""
|
||||
Gets the properties of a device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def utilization(self, device=None) -> int:
|
||||
"""
|
||||
Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
# =======================
|
||||
# random number generator APIs
|
||||
# =======================
|
||||
def get_rng_state(self, device=None) -> torch.Tensor:
|
||||
"""
|
||||
Returns the random number generator state of the specified GPU as a ByteTensor.
|
||||
"""
|
||||
return torch.get_rng_state(device)
|
||||
|
||||
def get_rng_state_all(self) -> List[torch.Tensor]:
|
||||
"""
|
||||
Returns a list of ByteTensor representing the random number states of all devices.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def set_rng_state(self, new_state: torch.ByteTensor, device: str = None) -> None:
|
||||
"""
|
||||
Sets the random number generator state of the specified GPU.
|
||||
"""
|
||||
torch.set_rng_state(new_state)
|
||||
|
||||
def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
|
||||
"""
|
||||
Sets the random number generator state of all devices.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def manual_seed(self, seed: int) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers for the current GPU.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def manual_seed_all(self, seed: int) -> None:
|
||||
"""
|
||||
Set the random seed for the all processes.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def seed(self) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers to a random number for the current GPU.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def seed_all(self) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers to a random number on all GPUs.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def initial_seed(self) -> int:
|
||||
"""
|
||||
Returns the current random seed of the current GPU.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
# =======================
|
||||
# memory management APIs
|
||||
# =======================
|
||||
|
||||
def empty_cache(self) -> None:
|
||||
"""
|
||||
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def memory_stats(self, device=None) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns a dictionary of CUDA memory allocator statistics for a given device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def memory_summary(self, device=None, abbreviated=False) -> str:
|
||||
"""
|
||||
Returns a human-readable printout of the current memory allocator statistics for a given device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def memory_snapshot(self):
|
||||
"""
|
||||
Returns a snapshot of the CUDA memory allocator state across all devices.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def memory_allocated(self, device=None) -> int:
|
||||
"""
|
||||
Returns the current GPU memory occupied by tensors in bytes for a given device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def max_memory_allocated(self, device=None) -> int:
|
||||
"""
|
||||
Returns the maximum GPU memory occupied by tensors in bytes for a given device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def reset_max_memory_allocated(self, device=None) -> None:
|
||||
"""
|
||||
Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def reset_max_memory_cached(self, device=None) -> None:
|
||||
"""
|
||||
Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def memory_reserved(self, device=None) -> int:
|
||||
"""
|
||||
Returns the current GPU memory managed by the caching allocator in bytes for a given device.
|
||||
"""
|
||||
return psutil.Process().memory_info().rss
|
||||
|
||||
def max_memory_reserved(self, device=None) -> int:
|
||||
"""
|
||||
Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
|
||||
"""
|
||||
return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
|
||||
def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
|
||||
"""
|
||||
Set memory fraction for a process.
|
||||
"""
|
||||
max_memory = int(psutil.virtual_memory().total * fraction)
|
||||
_, hard = resource.getrlimit(resource.RLIMIT_AS)
|
||||
resource.setrlimit(resource.RLIMIT_AS, (max_memory, hard))
|
||||
|
||||
def reset_peak_memory_stats(self, device=None) -> None:
|
||||
"""
|
||||
Resets the "peak" stats tracked by the CUDA memory allocator.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
# =======================
|
||||
# streams and events APIs
|
||||
# =======================
|
||||
|
||||
def Stream(self, device=None, priority=0, **kwargs):
|
||||
"""
|
||||
A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
|
||||
"""
|
||||
CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def current_stream(self, device=None):
|
||||
"""
|
||||
Returns the currently selected Stream for a given device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def default_stream(self, device=None):
|
||||
"""
|
||||
Returns the default Stream for a given device.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def set_stream(self, stream_):
|
||||
"""
|
||||
Sets the current stream.This is a wrapper API to set the stream.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
def stream(self, stream_):
|
||||
"""
|
||||
Wrapper around the Context-manager StreamContext that selects a given stream.
|
||||
"""
|
||||
raise RuntimeError("this method is not supported for cpu accelerator")
|
||||
|
||||
# =======================
|
||||
# amp APIs
|
||||
# =======================
|
||||
def autocast(
|
||||
self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
|
||||
) -> Callable:
|
||||
"""
|
||||
Return autocast function
|
||||
"""
|
||||
return nullcontext
|
|
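A small sketch of how the CPU fallback above behaves (assuming a CPU-only environment); as the file shows, most device-specific methods deliberately raise RuntimeError, while a few are mapped onto process-level equivalents:

from colossalai.accelerator import CpuAccelerator

acc = CpuAccelerator()

print(acc.is_available())        # True: the CPU backend is always available
print(acc.get_current_device())  # cpu

ctx = acc.autocast()             # the no-op nullcontext: mixed precision is skipped on CPU
print(acc.memory_reserved())     # resident set size of this process, via psutil

try:
    acc.set_device(0)            # binding to a device index is meaningless on CPU
except RuntimeError as err:
    print(err)                   # "this method is not supported for cpu accelerator"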
@ -1,7 +1,9 @@
|
|||
#!/usr/bin/env python
|
||||
from typing import Union
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from .base_accelerator import BaseAccelerator
|
||||
|
||||
|
@ -19,16 +21,26 @@ class CudaAccelerator(BaseAccelerator):
|
|||
# =======================
|
||||
# device APIs
|
||||
# =======================
|
||||
def get_current_device(self) -> torch.device:
|
||||
"""
|
||||
Return the current device.
|
||||
"""
|
||||
return torch.device(f"cuda:{torch.cuda.current_device()}")
|
||||
|
||||
def current_device(self) -> int:
|
||||
"""
|
||||
Return the current device index.
|
||||
"""
|
||||
return torch.cuda.current_device()
|
||||
|
||||
def set_device(self, device: Union[torch.device, int]) -> None:
|
||||
def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
|
||||
"""
|
||||
Bind the current process to a device.
|
||||
"""
|
||||
if device is None:
|
||||
if not dist.is_initialized():
|
||||
raise RuntimeError("Cannot get current device when distributed is not initialized.")
|
||||
device = dist.get_rank() % self.device_count()
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
def get_device_name(self, device: Union[torch.device, int]) -> str:
|
||||
|
@ -54,3 +66,211 @@ class CudaAccelerator(BaseAccelerator):
|
|||
Return the number of devices on the machine.
|
||||
"""
|
||||
return torch.cuda.device_count()
|
||||
|
||||
def get_device_capability(self, device=None) -> Tuple[int, int]:
|
||||
"""
|
||||
Gets the cuda capability of a device.
|
||||
"""
|
||||
return torch.cuda.get_device_capability(device)
|
||||
|
||||
def get_device_name(self, device=None) -> str:
|
||||
"""
|
||||
Gets the name of a device.
|
||||
"""
|
||||
return torch.cuda.get_device_name(device)
|
||||
|
||||
def get_device_properties(self, device):
|
||||
"""
|
||||
Gets the properties of a device.
|
||||
"""
|
||||
return torch.cuda.get_device_properties(device)
|
||||
|
||||
def utilization(self, device=None) -> int:
|
||||
"""
|
||||
Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi
|
||||
"""
|
||||
return torch.cuda.utilization(device)
|
||||
|
||||
# =======================
|
||||
# random number generator APIs
|
||||
# =======================
|
||||
def get_rng_state(self, device="cuda") -> torch.Tensor:
|
||||
"""
|
||||
Returns the random number generator state of the specified GPU as a ByteTensor.
|
||||
"""
|
||||
return torch.cuda.get_rng_state(device)
|
||||
|
||||
def get_rng_state_all(self) -> List[torch.Tensor]:
|
||||
"""
|
||||
Returns a list of ByteTensor representing the random number states of all devices.
|
||||
"""
|
||||
return torch.cuda.get_rng_state_all()
|
||||
|
||||
def set_rng_state(self, new_state: torch.ByteTensor, device: str = "cuda") -> None:
|
||||
"""
|
||||
Sets the random number generator state of the specified GPU.
|
||||
"""
|
||||
torch.cuda.set_rng_state(new_state, device)
|
||||
|
||||
def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
|
||||
"""
|
||||
Sets the random number generator state of all devices.
|
||||
"""
|
||||
torch.cuda.set_rng_state_all(new_states)
|
||||
|
||||
def manual_seed(self, seed: int) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers for the current GPU.
|
||||
"""
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
def manual_seed_all(self, seed: int) -> None:
|
||||
"""
|
||||
Set the random seed for the all processes.
|
||||
"""
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
def seed(self) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers to a random number for the current GPU.
|
||||
"""
|
||||
torch.cuda.seed()
|
||||
|
||||
def seed_all(self) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers to a random number on all GPUs.
|
||||
"""
|
||||
torch.cuda.seed_all()
|
||||
|
||||
def initial_seed(self) -> int:
|
||||
"""
|
||||
Returns the current random seed of the current GPU.
|
||||
"""
|
||||
return torch.cuda.initial_seed()
|
||||
|
||||
# =======================
|
||||
# memory management APIs
|
||||
# =======================
|
||||
|
||||
def empty_cache(self) -> None:
|
||||
"""
|
||||
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
|
||||
"""
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def memory_stats(self, device=None) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns a dictionary of CUDA memory allocator statistics for a given device.
|
||||
"""
|
||||
return torch.cuda.memory_stats(device=device)
|
||||
|
||||
def memory_summary(self, device=None, abbreviated=False) -> str:
|
||||
"""
|
||||
Returns a human-readable printout of the current memory allocator statistics for a given device.
|
||||
"""
|
||||
return torch.cuda.memory_summary(device=device, abbreviated=abbreviated)
|
||||
|
||||
def memory_snapshot(self):
|
||||
"""
|
||||
Returns a snapshot of the CUDA memory allocator state across all devices.
|
||||
"""
|
||||
return torch.cuda.memory_snapshot()
|
||||
|
||||
def memory_allocated(self, device=None) -> int:
|
||||
"""
|
||||
Returns the current GPU memory occupied by tensors in bytes for a given device.
|
||||
"""
|
||||
return torch.cuda.memory_allocated(device=device)
|
||||
|
||||
def max_memory_allocated(self, device=None) -> int:
|
||||
"""
|
||||
Returns the maximum GPU memory occupied by tensors in bytes for a given device.
|
||||
"""
|
||||
return torch.cuda.max_memory_allocated(device=device)
|
||||
|
||||
def reset_max_memory_allocated(self, device=None) -> None:
|
||||
"""
|
||||
Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
|
||||
"""
|
||||
torch.cuda.reset_max_memory_allocated(device=device)
|
||||
|
||||
def reset_max_memory_cached(self, device=None) -> None:
|
||||
"""
|
||||
Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
|
||||
"""
|
||||
torch.cuda.reset_max_memory_cached(device=device)
|
||||
|
||||
def memory_reserved(self, device=None) -> int:
|
||||
"""
|
||||
Returns the current GPU memory managed by the caching allocator in bytes for a given device.
|
||||
"""
|
||||
return torch.cuda.memory_reserved(device=device)
|
||||
|
||||
def max_memory_reserved(self, device=None) -> int:
|
||||
"""
|
||||
Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
|
||||
"""
|
||||
return torch.cuda.max_memory_reserved(device=device)
|
||||
|
||||
def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
|
||||
"""
|
||||
Set memory fraction for a process.
|
||||
"""
|
||||
torch.cuda.set_per_process_memory_fraction(fraction, device=device)
|
||||
|
||||
def reset_peak_memory_stats(self, device=None) -> None:
|
||||
"""
|
||||
Resets the "peak" stats tracked by the CUDA memory allocator.
|
||||
"""
|
||||
torch.cuda.reset_peak_memory_stats(device=device)
|
||||
|
||||
# =======================
|
||||
# streams and events APIs
|
||||
# =======================
|
||||
|
||||
def Stream(self, device=None, priority=0, **kwargs):
|
||||
"""
|
||||
A CUDA stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See cuda-semantics for details.
|
||||
"""
|
||||
return torch.cuda.Stream(device, priority, **kwargs)
|
||||
|
||||
def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
|
||||
"""
|
||||
CUDA events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize CUDA streams.
|
||||
"""
|
||||
return torch.cuda.Event(enable_timing, blocking, interprocess)
|
||||
|
||||
def current_stream(self, device=None):
|
||||
"""
|
||||
Returns the currently selected Stream for a given device.
|
||||
"""
|
||||
return torch.cuda.current_stream(device)
|
||||
|
||||
def default_stream(self, device=None):
|
||||
"""
|
||||
Returns the default Stream for a given device.
|
||||
"""
|
||||
return torch.cuda.default_stream(device)
|
||||
|
||||
def set_stream(self, stream_):
|
||||
"""
|
||||
Sets the current stream.This is a wrapper API to set the stream.
|
||||
"""
|
||||
torch.cuda.set_stream(stream_)
|
||||
|
||||
def stream(self, stream_):
|
||||
"""
|
||||
Wrapper around the Context-manager StreamContext that selects a given stream.
|
||||
"""
|
||||
return torch.cuda.stream(stream_)
|
||||
|
||||
# =======================
|
||||
# amp APIs
|
||||
# =======================
|
||||
def autocast(
|
||||
self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
|
||||
) -> Callable:
|
||||
"""
|
||||
Return autocast function
|
||||
"""
|
||||
return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
|
||||
|
|
|
@ -1,13 +1,17 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from typing import Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from .base_accelerator import BaseAccelerator
|
||||
|
||||
IS_NPU_AVAILABLE = False
|
||||
try:
|
||||
import torch_npu # noqa
|
||||
|
||||
IS_NPU_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
@ -26,16 +30,26 @@ class NpuAccelerator(BaseAccelerator):
|
|||
# =======================
|
||||
# device APIs
|
||||
# =======================
|
||||
def get_current_device(self) -> torch.device:
|
||||
"""
|
||||
Return the current device.
|
||||
"""
|
||||
return torch.device(f"npu:{torch.npu.current_device()}")
|
||||
|
||||
def current_device(self) -> int:
|
||||
"""
|
||||
Return the current device index.
|
||||
"""
|
||||
return torch.npu.current_device()
|
||||
|
||||
def set_device(self, device: Union[torch.device, int]) -> None:
|
||||
def set_device(self, device: Optional[Union[torch.device, int]] = None) -> None:
|
||||
"""
|
||||
Bind the current process to a device.
|
||||
"""
|
||||
if device is None:
|
||||
if not dist.is_initialized():
|
||||
raise RuntimeError("Cannot get current device when distributed is not initialized.")
|
||||
device = dist.get_rank() % self.device_count()
|
||||
torch.npu.set_device(device)
|
||||
|
||||
def get_device_name(self, device: Union[torch.device, int]) -> str:
|
||||
|
@ -61,3 +75,211 @@ class NpuAccelerator(BaseAccelerator):
|
|||
Return the number of devices on the machine.
|
||||
"""
|
||||
return torch.npu.device_count()
|
||||
|
||||
def get_device_capability(self, device=None) -> Tuple[int, int]:
|
||||
"""
|
||||
Gets the npu capability of a device.
|
||||
"""
|
||||
return torch.npu.get_device_capability(device)
|
||||
|
||||
def get_device_name(self, device=None) -> str:
|
||||
"""
|
||||
Gets the name of a device.
|
||||
"""
|
||||
return torch.npu.get_device_name(device)
|
||||
|
||||
def get_device_properties(self, device):
|
||||
"""
|
||||
Gets the properties of a device.
|
||||
"""
|
||||
return torch.npu.get_device_properties(device)
|
||||
|
||||
def utilization(self, device=None) -> int:
|
||||
"""
|
||||
Returns the percent of time over the past sample period during which one or more kernels was executing on the GPU as given by nvidia-smi
|
||||
"""
|
||||
return torch.npu.utilization(device)
|
||||
|
||||
# =======================
|
||||
# random number generator APIs
|
||||
# =======================
|
||||
def get_rng_state(self, device="npu") -> torch.Tensor:
|
||||
"""
|
||||
Returns the random number generator state of the specified GPU as a ByteTensor.
|
||||
"""
|
||||
return torch.npu.get_rng_state(device)
|
||||
|
||||
def get_rng_state_all(self) -> List[torch.Tensor]:
|
||||
"""
|
||||
Returns a list of ByteTensor representing the random number states of all devices.
|
||||
"""
|
||||
return torch.npu.get_rng_state_all()
|
||||
|
||||
def set_rng_state(self, new_state: torch.ByteTensor, device: str = "npu") -> None:
|
||||
"""
|
||||
Sets the random number generator state of the specified GPU.
|
||||
"""
|
||||
torch.npu.set_rng_state(new_state, device)
|
||||
|
||||
def set_rng_state_all(self, new_states: List[torch.ByteTensor]) -> None:
|
||||
"""
|
||||
Sets the random number generator state of all devices.
|
||||
"""
|
||||
torch.npu.set_rng_state_all(new_states)
|
||||
|
||||
def manual_seed(self, seed: int) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers for the current GPU.
|
||||
"""
|
||||
torch.npu.manual_seed(seed)
|
||||
|
||||
def manual_seed_all(self, seed: int) -> None:
|
||||
"""
|
||||
Set the random seed for the all processes.
|
||||
"""
|
||||
torch.npu.manual_seed_all(seed)
|
||||
|
||||
def seed(self) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers to a random number for the current GPU.
|
||||
"""
|
||||
torch.npu.seed()
|
||||
|
||||
def seed_all(self) -> None:
|
||||
"""
|
||||
Sets the seed for generating random numbers to a random number on all GPUs.
|
||||
"""
|
||||
torch.npu.seed_all()
|
||||
|
||||
def initial_seed(self) -> int:
|
||||
"""
|
||||
Returns the current random seed of the current GPU.
|
||||
"""
|
||||
return torch.npu.initial_seed()
|
||||
|
||||
# =======================
|
||||
# memory management APIs
|
||||
# =======================
|
||||
|
||||
def empty_cache(self) -> None:
|
||||
"""
|
||||
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible in nvidia-smi.
|
||||
"""
|
||||
torch.npu.empty_cache()
|
||||
|
||||
def memory_stats(self, device=None) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns a dictionary of npu memory allocator statistics for a given device.
|
||||
"""
|
||||
return torch.npu.memory_stats(device=device)
|
||||
|
||||
def memory_summary(self, device=None, abbreviated=False) -> str:
|
||||
"""
|
||||
Returns a human-readable printout of the current memory allocator statistics for a given device.
|
||||
"""
|
||||
return torch.npu.memory_summary(device=device, abbreviated=abbreviated)
|
||||
|
||||
def memory_snapshot(self):
|
||||
"""
|
||||
Returns a snapshot of the npu memory allocator state across all devices.
|
||||
"""
|
||||
return torch.npu.memory_snapshot()
|
||||
|
||||
def memory_allocated(self, device=None) -> int:
|
||||
"""
|
||||
Returns the current GPU memory occupied by tensors in bytes for a given device.
|
||||
"""
|
||||
return torch.npu.memory_allocated(device=device)
|
||||
|
||||
def max_memory_allocated(self, device=None) -> int:
|
||||
"""
|
||||
Returns the maximum GPU memory occupied by tensors in bytes for a given device.
|
||||
"""
|
||||
return torch.npu.max_memory_allocated(device=device)
|
||||
|
||||
def reset_max_memory_allocated(self, device=None) -> None:
|
||||
"""
|
||||
Resets the starting point in tracking maximum GPU memory occupied by tensors for a given device.
|
||||
"""
|
||||
torch.npu.reset_max_memory_allocated(device=device)
|
||||
|
||||
def reset_max_memory_cached(self, device=None) -> None:
|
||||
"""
|
||||
Resets the starting point in tracking maximum GPU memory managed by the caching allocator for a given device.
|
||||
"""
|
||||
torch.npu.reset_max_memory_cached(device=device)
|
||||
|
||||
def memory_reserved(self, device=None) -> int:
|
||||
"""
|
||||
Returns the current GPU memory managed by the caching allocator in bytes for a given device.
|
||||
"""
|
||||
return torch.npu.memory_reserved(device=device)
|
||||
|
||||
def max_memory_reserved(self, device=None) -> int:
|
||||
"""
|
||||
Returns the maximum GPU memory managed by the caching allocator in bytes for a given device.
|
||||
"""
|
||||
return torch.npu.max_memory_reserved(device=device)
|
||||
|
||||
def set_per_process_memory_fraction(self, fraction: float, device=None) -> None:
|
||||
"""
|
||||
Set memory fraction for a process.
|
||||
"""
|
||||
torch.npu.set_per_process_memory_fraction(fraction, device=device)
|
||||
|
||||
def reset_peak_memory_stats(self, device=None) -> None:
|
||||
"""
|
||||
Resets the "peak" stats tracked by the npu memory allocator.
|
||||
"""
|
||||
torch.npu.reset_peak_memory_stats(device=device)
|
||||
|
||||
# =======================
|
||||
# streams and events APIs
|
||||
# =======================
|
||||
|
||||
def Stream(self, device=None, priority=0, **kwargs):
|
||||
"""
|
||||
A npu stream is a linear sequence of execution that belongs to a specific device, independent from other streams. See npu-semantics for details.
|
||||
"""
|
||||
return torch.npu.Stream(device, priority, **kwargs)
|
||||
|
||||
def Event(self, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
|
||||
"""
|
||||
npu events are synchronization markers that can be used to monitor the device's progress, to accurately measure timing, and to synchronize npu streams.
|
||||
"""
|
||||
return torch.npu.Event(enable_timing, blocking, interprocess)
|
||||
|
||||
def current_stream(self, device=None):
|
||||
"""
|
||||
Returns the currently selected Stream for a given device.
|
||||
"""
|
||||
return torch.npu.current_stream(device)
|
||||
|
||||
def default_stream(self, device=None):
|
||||
"""
|
||||
Returns the default Stream for a given device.
|
||||
"""
|
||||
return torch.npu.default_stream(device)
|
||||
|
||||
def set_stream(self, stream_):
|
||||
"""
|
||||
Sets the current stream.This is a wrapper API to set the stream.
|
||||
"""
|
||||
torch.npu.set_stream(stream_)
|
||||
|
||||
def stream(self, stream_):
|
||||
"""
|
||||
Wrapper around the Context-manager StreamContext that selects a given stream.
|
||||
"""
|
||||
return torch.npu.stream(stream_)
|
||||
|
||||
# =======================
|
||||
# amp APIs
|
||||
# =======================
|
||||
def autocast(
|
||||
self, enabled: bool = True, dtype: torch.dtype = torch.float16, cache_enabled: bool = True
|
||||
) -> Callable:
|
||||
"""
|
||||
Return autocast function
|
||||
"""
|
||||
return torch.npu.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
|
||||
|
|
|
@ -7,8 +7,8 @@ from typing import Dict
|
|||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
__all__ = ["BaseGradScaler"]
|
||||
|
||||
|
@ -23,7 +23,7 @@ class BaseGradScaler(ABC):
|
|||
|
||||
def __init__(self, initial_scale: float, verbose: bool):
|
||||
assert initial_scale > 0
|
||||
self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float)
|
||||
self._scale = torch.tensor([initial_scale], device=get_accelerator().get_current_device(), dtype=torch.float)
|
||||
self._verbose = verbose
|
||||
|
||||
if self._verbose:
|
||||
|
|
|
@@ -5,7 +5,7 @@ from typing import Optional

import torch

from colossalai.utils.device import get_current_device
from colossalai.accelerator import get_accelerator

from .base_grad_scaler import BaseGradScaler

@@ -37,14 +37,20 @@ class DynamicGradScaler(BaseGradScaler):
hysteresis: int = 2,
verbose: bool = False,
):
a = get_accelerator()
a.device_count()
super().__init__(initial_scale, verbose)
if min_scale:
self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float)
self._min_scale = torch.tensor(
[min_scale], device=get_accelerator().get_current_device(), dtype=torch.float
)
else:
self._min_scale = None

if max_scale:
self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float)
self._max_scale = torch.tensor(
[max_scale], device=get_accelerator().get_current_device(), dtype=torch.float
)
else:
self._max_scale = None

@@ -117,7 +123,7 @@ class DynamicGradScaler(BaseGradScaler):
return state_dict

def load_state_dict(self, state_dict):
self._scale = state_dict["scale"].to(get_current_device())
self._scale = state_dict["scale"].to(get_accelerator().get_current_device())
self._growth_factor = state_dict["growth_factor"]
self._backoff_factor = state_dict["backoff_factor"]
self._hysteresis = state_dict["hysteresis"]
@ -5,8 +5,8 @@ import torch
|
|||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .base import MixedPrecisionMixin
|
||||
|
||||
|
@ -40,7 +40,7 @@ class FP16MixedPrecisionMixin(MixedPrecisionMixin):
|
|||
max_scale=max_scale,
|
||||
)
|
||||
self.optim_state = OptimState.UNSCALED
|
||||
self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())
|
||||
self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())
|
||||
|
||||
@property
|
||||
def loss_scale(self) -> float:
|
||||
|
|
|
@ -4,10 +4,10 @@ from typing import Dict, Tuple
|
|||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .base_offload_module import BaseOffloadModule
|
||||
from .region import Region
|
||||
|
@ -79,7 +79,9 @@ class AMPOptimizer(OptimizerWrapper):
|
|||
hysteresis=hysteresis,
|
||||
max_scale=max_scale,
|
||||
)
|
||||
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
|
||||
self._found_overflow: torch.Tensor = torch.zeros(
|
||||
1, dtype=torch.int64, device=get_accelerator().get_current_device()
|
||||
)
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
def _set_grad_ptr(self):
|
||||
|
|
|
@ -11,7 +11,7 @@ except:
|
|||
import torch
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.utils.device import get_current_device
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
from .region import Region
|
||||
from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
|
||||
|
@ -57,7 +57,10 @@ class Solver(ABC):
|
|||
if memory_budget > 0:
|
||||
self.memory_budget = memory_budget * self.error_factor
|
||||
else:
|
||||
self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor
|
||||
self.memory_budget = (
|
||||
torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory
|
||||
* self.error_factor
|
||||
)
|
||||
|
||||
self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
|
||||
self.comp_power: float = self._extract_computing_power()
|
||||
|
|
|
@@ -5,8 +5,8 @@ import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer

from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils.device import autocast

from .mixed_precision_base import MixedPrecision

@@ -89,7 +89,7 @@ class TorchAMPModule(ModelWrapper):
super().__init__(module)

def forward(self, *args, **kwargs):
with autocast():
with get_accelerator().autocast():
return self.module(*args, **kwargs)
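TorchAMPModule.forward above now enters the active accelerator's own autocast context instead of the CUDA-specific colossalai.utils.device.autocast. A hedged sketch of the same pattern in user code (the model, optimizer, and data are illustrative; a CUDA or NPU backend is assumed):

import torch
import torch.nn as nn

from colossalai.accelerator import get_accelerator

acc = get_accelerator()
device = acc.get_current_device()

model = nn.Linear(16, 16).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
data = torch.randn(32, 16, device=device)

optimizer.zero_grad()
with acc.autocast():                     # cuda or npu autocast, depending on the backend
    loss = model(data).float().sum()
loss.backward()
optimizer.step()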
@ -12,6 +12,7 @@ from torch.optim import Optimizer
|
|||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.accelerator import IS_NPU_AVAILABLE, get_accelerator
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
|
||||
from colossalai.checkpoint_io.utils import (
|
||||
get_model_base_filenames,
|
||||
|
@ -24,8 +25,6 @@ from colossalai.checkpoint_io.utils import (
|
|||
from colossalai.cluster import DistCoordinator, ProcessGroupMesh
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.device import IS_NPU_AVAILABLE
|
||||
from colossalai.zero import GeminiDDP, GeminiOptimizer
|
||||
from colossalai.zero.gemini.memory_tracer import MemStats
|
||||
|
||||
|
@ -367,7 +366,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
assert placement_policy == "static", "NPU only supports static placement policy"
|
||||
self.gemini_config = dict(
|
||||
chunk_config_dict=chunk_config_dict,
|
||||
chunk_init_device=(chunk_init_device or get_current_device()),
|
||||
chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()),
|
||||
placement_policy=placement_policy,
|
||||
enable_gradient_accumulation=enable_gradient_accumulation,
|
||||
shard_param_frac=shard_param_frac,
|
||||
|
|
|
@ -18,6 +18,7 @@ from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from colossalai.accelerator import get_accelerator
from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh

@ -29,7 +30,6 @@ from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.d_tensor.api import is_distributed_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
from colossalai.utils.device import get_current_device

from .pp_plugin_base import PipelinePluginBase

@ -82,7 +82,7 @@ class HybridParallelModule(ModelWrapper):
self.mixed_precision = torch.bfloat16
if self.mixed_precision is not None:
module = module.to(self.mixed_precision)
module = module.to(get_current_device())
module = module.to(get_accelerator().get_current_device())

# setting input type cast when using mixed precision
self.convert_fn = None

@ -346,7 +346,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):

if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
total_norm_cuda = torch.tensor(
[float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
)
if self.tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
if self.pp_size > 1:

@ -385,7 +387,9 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):

total_norm_exponentiated += grad_norm_exponentiated

total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
total_norm_exponentiated_cuda = torch.tensor(
[float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
)
if self.tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)

@ -543,7 +547,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
# so we need to calculate the norm of 'tp' and 'pp' gradients.
total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)

total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
total_norm_cuda = torch.tensor(
[float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
)

if self.tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)

@ -586,7 +592,9 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):

total_norm_exponentiated += grad_norm_exponentiated

total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
total_norm_exponentiated_cuda = torch.tensor(
[float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
)
if self.tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)

@ -798,7 +806,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
# so we only need to calculate the norm 'tp' of 'pp' gradients.
total_norm = super()._compute_grad_norm(gradients, norm_type)

total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
total_norm_cuda = torch.tensor(
[float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float32
)

if tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)

@ -837,7 +847,9 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):

total_norm_exponentiated += grad_norm_exponentiated

total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
total_norm_exponentiated_cuda = torch.tensor(
[float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float32
)
if dp_size > 1:
# compute norm in dp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
|
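Every optimizer hunk above repeats one idea: the locally computed gradient norm is packed into a scalar tensor on the accelerator's current device so it can be all-reduced across the tensor-, pipeline- or data-parallel group. A condensed sketch, with pg standing in for whichever process group applies (global_grad_norm is an illustrative helper, not a method of the plugin):

    import torch
    import torch.distributed as dist
    from colossalai.accelerator import get_accelerator

    def global_grad_norm(local_norm: float, pg) -> float:
        # the scalar must live on the active device so NCCL/HCCL can reduce it
        norm_t = torch.tensor([local_norm], device=get_accelerator().get_current_device(), dtype=torch.float32)
        dist.all_reduce(norm_t, op=dist.ReduceOp.MAX, group=pg)
        return norm_t.item()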
@ -12,6 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader

from colossalai.accelerator import get_accelerator
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
from colossalai.checkpoint_io.utils import (
get_optimizer_base_filenames,

@ -24,7 +25,6 @@ from colossalai.checkpoint_io.utils import (
sharded_optimizer_loading_epilogue,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import LowLevelZeroOptimizer

from .dp_plugin_base import DPPluginBase

@ -52,7 +52,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
self.dtype = torch.bfloat16
if self.dtype is not None:
module = module.to(self.dtype)
module = module.to(get_current_device())
module = module.to(get_accelerator().get_current_device())
self.module = module
self.convert_fn = None
if self.dtype is not None:
|
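The LowLevelZeroModel change keeps the same two-step preparation: cast to the mixed-precision dtype first, then move to the device reported by the accelerator. A standalone sketch (prepare_module is an illustrative name, not the wrapper's real method):

    import torch
    import torch.nn as nn
    from colossalai.accelerator import get_accelerator

    def prepare_module(module: nn.Module, dtype: torch.dtype = torch.float16) -> nn.Module:
        if dtype is not None:
            module = module.to(dtype)  # cast parameters/buffers to the AMP dtype
        return module.to(get_accelerator().get_current_device())  # then place them on the active device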
@ -6,12 +6,12 @@ import warnings
from pathlib import Path
from typing import Dict, Union

import torch
import torch.distributed as dist

from colossalai.accelerator import get_accelerator
from colossalai.context import Config
from colossalai.logging import get_dist_logger
from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed
from colossalai.utils import set_seed


def launch(

@ -47,17 +47,18 @@ def launch(
if rank == 0:
warnings.warn("`config` is deprecated and will be removed soon.")

if IS_NPU_AVAILABLE and backend == "nccl":
backend = "hccl"
cur_accelerator = get_accelerator()

backend = cur_accelerator.communication_backend

# init default process group
init_method = f"tcp://[{host}]:{port}"
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)

# set cuda device
if torch.cuda.is_available() or IS_NPU_AVAILABLE:
# if local rank is not given, calculate automatically
set_device(local_rank)
if cur_accelerator.support_set_device:
cur_accelerator.set_device(local_rank)

set_seed(seed)

|
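launch() no longer special-cases NPU: the process-group backend and the device binding both come from the accelerator. A sketch of the resulting flow on a single node (init_distributed is an illustrative wrapper, not the real function):

    import torch.distributed as dist
    from colossalai.accelerator import get_accelerator

    def init_distributed(rank: int, world_size: int, local_rank: int, host: str, port: int) -> None:
        acc = get_accelerator()
        # "nccl" on CUDA, "hccl" on NPU, or whatever backend the accelerator reports
        dist.init_process_group(
            rank=rank,
            world_size=world_size,
            backend=acc.communication_backend,
            init_method=f"tcp://[{host}]:{port}",
        )
        if acc.support_set_device:  # CPU-style accelerators may not support binding a device index
            acc.set_device(local_rank)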
@ -6,7 +6,7 @@ import torch
import torch.nn.functional as F
from einops import rearrange

from colossalai.utils.device import get_current_device
from colossalai.accelerator import get_accelerator


class Unpad(torch.autograd.Function):

@ -70,7 +70,9 @@ class SeqLenInfo:
cu_seqlens: torch.Tensor = None

@staticmethod
def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()):
def materialize(
attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
):
if attn_mask is not None:
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
|
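The new default still captures the device at import time, just as the old one did; callers that initialize distributed state later can simply pass the device explicitly. A usage sketch (the SeqLenInfo import path is omitted because it is not part of this diff):

    import torch
    from colossalai.accelerator import get_accelerator

    device = get_accelerator().get_current_device()
    attn_mask = torch.ones(2, 128, dtype=torch.int32, device=device)  # [batch, seq_len] padding mask
    # seq_info = SeqLenInfo.materialize(attn_mask=attn_mask, device=device)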
@ -1,7 +1,7 @@
|
|||
import torch
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .bias_dropout_add import bias_dropout_add_fused_train
|
||||
from .bias_gelu import bias_gelu_impl
|
||||
|
@ -46,11 +46,13 @@ def warmup_jit_fusion(
|
|||
):
|
||||
"""Compile JIT functions before the main training steps"""
|
||||
|
||||
embed = Embedding(vocab_size, hidden_size).to(get_current_device())
|
||||
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device())
|
||||
linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device())
|
||||
embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
|
||||
linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())
|
||||
linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_accelerator().get_current_device())
|
||||
|
||||
x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device())
|
||||
x = torch.randint(
|
||||
vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_accelerator().get_current_device()
|
||||
)
|
||||
x = embed(x)
|
||||
y, y_bias = linear_1(x)
|
||||
z, z_bias = linear_2(y)
|
||||
|
@ -58,8 +60,8 @@ def warmup_jit_fusion(
|
|||
# prop and recomputation
|
||||
for bias_grad, input_grad in zip([True, True], [False, True]):
|
||||
for _ in range(10):
|
||||
bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device())
|
||||
input_ = torch.rand_like(y, dtype=dtype, device=get_current_device())
|
||||
bias = torch.rand_like(y_bias, dtype=dtype, device=get_accelerator().get_current_device())
|
||||
input_ = torch.rand_like(y, dtype=dtype, device=get_accelerator().get_current_device())
|
||||
bias.requires_grad, input_.requires_grad = bias_grad, input_grad
|
||||
bias_gelu_impl(input_, bias)
|
||||
|
||||
|
@ -69,9 +71,9 @@ def warmup_jit_fusion(
|
|||
# prop and recomputation
|
||||
for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
|
||||
for _ in range(10):
|
||||
input_ = torch.rand_like(z, dtype=dtype, device=get_current_device())
|
||||
residual = torch.rand_like(x, dtype=dtype, device=get_current_device())
|
||||
bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device())
|
||||
input_ = torch.rand_like(z, dtype=dtype, device=get_accelerator().get_current_device())
|
||||
residual = torch.rand_like(x, dtype=dtype, device=get_accelerator().get_current_device())
|
||||
bias = torch.rand_like(z_bias, dtype=dtype, device=get_accelerator().get_current_device())
|
||||
input_.requires_grad = input_grad
|
||||
bias.requires_grad = bias_grad
|
||||
residual.requires_grad = residual_grad
|
||||
|
|
|
@ -1,18 +1,19 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from colossalai.utils.device import autocast

import torch.nn as nn
from torch import Tensor
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.legacy.utils import clip_grad_norm_fp32

from ._grad_scaler import GradScaler

autocast = get_accelerator().autocast


class TorchAMPOptimizer(OptimizerWrapper):
"""A wrapper class which integrate Pytorch AMP with an optimizer
|
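The module-level alias keeps the legacy call sites untouched: anywhere this file previously wrote with autocast(): it now obtains the context manager from the active accelerator instead of the CUDA-only helper. A sketch, assuming get_accelerator().autocast() returns a torch autocast context manager:

    from colossalai.accelerator import get_accelerator

    autocast = get_accelerator().autocast  # same name the old colossalai.utils.device helper exported

    def forward_fp16(model, batch):
        with autocast():  # mixed-precision region on whichever backend is active
            return model(batch)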
@ -8,9 +8,9 @@ from typing import List, Tuple, Union
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
|
||||
|
||||
|
@ -43,12 +43,16 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
|
|||
def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
|
||||
if isinstance(recv_shapes, torch.Size):
|
||||
recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
|
||||
buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
|
||||
buffer_recv = torch.empty(
|
||||
recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
return buffer_recv, recv_split
|
||||
buffer_recv = []
|
||||
for recv_shape in recv_shapes:
|
||||
recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors)
|
||||
tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
|
||||
tensor_recv = torch.empty(
|
||||
recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
buffer_recv.append(tensor_recv)
|
||||
return buffer_recv, recv_split
|
||||
|
||||
|
|
|
@ -3,9 +3,9 @@

import torch

from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.utils import get_current_device, synchronize


def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor:

@ -29,7 +29,7 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) ->
current_rank = gpc.get_global_rank()

tensor_recv_prev = torch.empty(
buffer_shape, requires_grad=True, device=get_current_device(), dtype=tensor_send_next.dtype
buffer_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=tensor_send_next.dtype
)

# send to next rank

@ -52,6 +52,6 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) ->
req.wait()

# To protect against race condition when using batch_isend_irecv().
synchronize()
get_accelerator().synchronize()

return tensor_recv_prev
|
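The behavioural point in the last hunk is the barrier at the end: after waiting on the async send/recv handles, the code now synchronizes through the accelerator instead of the CUDA-specific colossalai.utils.synchronize(). A sketch with reqs standing in for the handles returned by dist.batch_isend_irecv():

    from colossalai.accelerator import get_accelerator

    def wait_for_p2p(reqs) -> None:
        for req in reqs:
            req.wait()
        # guards against the batch_isend_irecv race condition noted in the comment above
        get_accelerator().synchronize()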
@ -3,9 +3,9 @@ from typing import List, Tuple, Union
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
TensorShape = Union[torch.Size, List[int], Tuple[int]]
|
||||
|
||||
|
@ -35,7 +35,7 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
|
|||
if next_rank is None:
|
||||
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
|
||||
tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()}
|
||||
if isinstance(obj, torch.Tensor):
|
||||
send_obj_nums = torch.tensor(1, **tensor_kwargs)
|
||||
dist.send(send_obj_nums, next_rank)
|
||||
|
@ -74,7 +74,7 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
|
|||
if prev_rank is None:
|
||||
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
|
||||
|
||||
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
|
||||
tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()}
|
||||
recv_obj_nums = torch.empty((), **tensor_kwargs)
|
||||
dist.recv(recv_obj_nums, prev_rank)
|
||||
if recv_obj_nums.item() == 1:
|
||||
|
|
|
@ -6,8 +6,8 @@ from typing import Callable, Iterable

import torch

from colossalai.accelerator import get_accelerator
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device


class BaseSchedule(ABC):

@ -29,12 +29,12 @@ class BaseSchedule(ABC):
def _move_tensor(element):
if torch.is_tensor(element):
if not element.is_cuda:
return element.to(get_current_device()).detach()
return element.to(get_accelerator().get_current_device()).detach()
return element

def _move_to_device(self, data):
if isinstance(data, torch.Tensor):
data = data.to(get_current_device())
data = data.to(get_accelerator().get_current_device())
elif isinstance(data, (list, tuple)):
data_to_return = []
for element in data:
|
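The schedule's helper walks nested containers and moves every tensor onto the accelerator's device. A condensed sketch of the same idea (dict handling is omitted, mirroring the cases visible in the hunk):

    import torch
    from colossalai.accelerator import get_accelerator

    def move_to_device(data):
        device = get_accelerator().get_current_device()
        if torch.is_tensor(data):
            return data.to(device)
        if isinstance(data, (list, tuple)):
            return type(data)(move_to_device(x) for x in data)
        return data  # non-tensor leaves pass through unchanged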
@ -7,12 +7,12 @@ from typing import Callable, List, Tuple, Union
|
|||
import torch.cuda
|
||||
|
||||
import colossalai.legacy.communication as comm
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.amp.naive_amp import NaiveAMPModel
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
from ._base_schedule import BaseSchedule
|
||||
|
||||
|
@ -352,7 +352,7 @@ class PipelineSchedule(BaseSchedule):
|
|||
output_objs = []
|
||||
return_tensors = []
|
||||
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
|
||||
accum_loss = torch.zeros(1, device=get_current_device())
|
||||
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
else:
|
||||
accum_loss = None
|
||||
# Used for tensor meta information communication
|
||||
|
@ -584,7 +584,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
|||
if not forward_only:
|
||||
output_obj_grads = [[] for _ in range(len(model))]
|
||||
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
|
||||
accum_loss = torch.zeros(1, device=get_current_device())
|
||||
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
else:
|
||||
accum_loss = None
|
||||
|
||||
|
|
|
@ -6,10 +6,10 @@ from typing import Iterable, Tuple
|
|||
import torch.cuda
|
||||
|
||||
import colossalai.legacy.communication.p2p_v2 as comm
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.legacy.engine import Engine
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
from ._pipeline_schedule import PipelineSchedule
|
||||
|
||||
|
@ -99,7 +99,7 @@ class PipelineScheduleV2(PipelineSchedule):
|
|||
output_objs = []
|
||||
return_tensors = []
|
||||
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
|
||||
accum_loss = torch.zeros(1, device=get_current_device())
|
||||
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
else:
|
||||
accum_loss = None
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler
|
|||
from torch.optim.optimizer import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.context import Config, ConfigException
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.legacy.amp import AMP_TYPE, convert_to_amp
|
||||
|
@ -34,7 +35,6 @@ from colossalai.legacy.utils import is_using_ddp, is_using_pp, is_using_sequence
|
|||
from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2
|
||||
from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def get_default_parser():
|
||||
|
@ -309,9 +309,9 @@ def initialize(
|
|||
else:
|
||||
if isinstance(model, nn.Module):
|
||||
# first sync model across dp ranks
|
||||
model.to(get_current_device())
|
||||
model.to(get_accelerator().get_current_device())
|
||||
elif isinstance(model, Callable):
|
||||
model = model().to(get_current_device())
|
||||
model = model().to(get_accelerator().get_current_device())
|
||||
|
||||
# optimizer maybe a optimizer_cls
|
||||
if isinstance(optimizer, Callable):
|
||||
|
|
|
@ -3,8 +3,8 @@ from typing import Callable
|
|||
|
||||
from torch import dtype, nn
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.nn import init
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D
|
||||
from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D
|
||||
|
@ -83,7 +83,7 @@ class Embedding(ColossalaiModule):
|
|||
embed = (
|
||||
nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, **kwargs)
|
||||
.to(dtype)
|
||||
.to(get_current_device())
|
||||
.to(get_accelerator().get_current_device())
|
||||
)
|
||||
weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
|
||||
elif num_embeddings <= vocab_parallel_limit:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from torch import nn
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
from ..parallel_1d import LayerNorm1D
|
||||
from ..parallel_2d import LayerNorm2D
|
||||
|
@ -36,7 +36,7 @@ class LayerNorm(ColossalaiModule):
|
|||
def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None:
|
||||
tensor_parallel = get_tensor_parallel_mode()
|
||||
if tensor_parallel is None:
|
||||
norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
|
||||
norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_accelerator().get_current_device())
|
||||
else:
|
||||
norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
|
||||
super().__init__(norm)
|
||||
|
|
|
@ -10,6 +10,7 @@ import torch.nn.functional as F
|
|||
from torch import Tensor
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.kernel import LayerNorm
|
||||
from colossalai.legacy.communication import broadcast
|
||||
from colossalai.legacy.context import ParallelMode, seed
|
||||
|
@ -22,7 +23,6 @@ from colossalai.legacy.utils.checkpointing import (
|
|||
partition_tensor_parallel_state_dict,
|
||||
)
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
from ..base_layer import ParallelLayer
|
||||
from ..colossalai_layer._utils import ColossalaiModule
|
||||
|
@ -221,7 +221,7 @@ class Classifier1D(ParallelLayer):
|
|||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
self.has_weight = False
|
||||
|
@ -357,7 +357,7 @@ class VocabParallelClassifier1D(ParallelLayer):
|
|||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
self.has_weight = False
|
||||
|
@ -499,7 +499,7 @@ class Linear1D_Col(ParallelLayer):
|
|||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
|
||||
|
||||
if bias:
|
||||
|
@ -638,7 +638,7 @@ class Linear1D_Row(ParallelLayer):
|
|||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
|
@ -802,7 +802,9 @@ class Embedding1D(ParallelLayer):
|
|||
self.embed_kwargs = kwargs
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
|
||||
torch.empty(
|
||||
(num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
|
||||
self.reset_parameters(weight_initializer)
|
||||
|
@ -912,7 +914,11 @@ class VocabParallelEmbedding1D(ParallelLayer):
|
|||
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype)
|
||||
torch.empty(
|
||||
(self.num_embeddings_per_partition, self.embed_dim),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
self.reset_parameters(weight_initializer)
|
||||
|
|
|
@ -5,10 +5,10 @@ import torch.distributed as dist
|
|||
from torch import Tensor
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def matmul_2d(
|
||||
|
@ -250,7 +250,7 @@ class Matmul_AB_2D(torch.autograd.Function):
|
|||
B_shape = B.shape
|
||||
B = B.reshape((-1, B_shape[-1]))
|
||||
C_shape = (A.shape[0], B.shape[-1])
|
||||
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
|
||||
|
||||
# use circular buffer to store the communication tensor
|
||||
# 2 is enough for all cases
|
||||
|
@ -399,7 +399,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
|
|||
B_shape = B.shape
|
||||
B = B.reshape((-1, B_shape[-1]))
|
||||
C_shape = (A.shape[0], B.shape[0])
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
|
||||
|
||||
# use circular buffer to store the communication tensor
|
||||
# 2 is enough for all cases
|
||||
|
@ -556,7 +556,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
|
|||
B_shape = B.shape
|
||||
B = B.reshape((-1, B_shape[-1]))
|
||||
C_shape = (A.shape[-1], B.shape[-1])
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
|
||||
|
||||
# use circular buffer to store the communication tensor
|
||||
# 2 is enough for all cases
|
||||
|
|
|
@ -8,6 +8,7 @@ import torch.nn.functional as F
|
|||
from torch import Tensor
|
||||
from torch.nn import Parameter
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.communication import broadcast
|
||||
from colossalai.legacy.context import ParallelMode, seed
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
|
@ -18,7 +19,6 @@ from colossalai.legacy.utils.checkpointing import (
|
|||
partition_tensor_parallel_state_dict,
|
||||
)
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
from ..base_layer import ParallelLayer
|
||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||
|
@ -82,7 +82,7 @@ class Linear2D(ParallelLayer):
|
|||
self.hidden_size_per_partition = divide(self.out_features, self.summa_dim)
|
||||
|
||||
# create weight, shape: [k/q, h/q]
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)
|
||||
)
|
||||
|
@ -259,7 +259,7 @@ class LayerNorm2D(ParallelLayer):
|
|||
self.partitioned_partition = divide(normalized_shape, self.summa_dim**2)
|
||||
|
||||
# create parameters
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
|
||||
self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
|
||||
if bias:
|
||||
|
@ -438,18 +438,24 @@ class PatchEmbedding2D(ParallelLayer):
|
|||
self.weight = Parameter(
|
||||
torch.empty(
|
||||
(self.embed_size_per_partition, in_chans, *self.patch_size),
|
||||
device=get_current_device(),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype))
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
|
||||
)
|
||||
|
||||
self.cls_token = Parameter(
|
||||
torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)
|
||||
torch.zeros(
|
||||
(1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
self.pos_embed = Parameter(
|
||||
torch.zeros(
|
||||
(1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype
|
||||
(1, self.num_patches + 1, self.embed_size_per_partition),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -619,7 +625,9 @@ class Embedding2D(ParallelLayer):
|
|||
self.embed_kwargs = kwargs
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
|
||||
torch.empty(
|
||||
(num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
|
||||
self.reset_parameters(weight_initializer)
|
||||
|
@ -758,7 +766,7 @@ class VocabParallelEmbedding2D(ParallelLayer):
|
|||
self.weight = Parameter(
|
||||
torch.empty(
|
||||
(self.num_embeddings_per_partition, self.embed_dim_per_partition),
|
||||
device=get_current_device(),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
@ -895,11 +903,18 @@ class Classifier2D(ParallelLayer):
|
|||
self.has_weight = False
|
||||
else:
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)
|
||||
torch.empty(
|
||||
self.num_classes,
|
||||
self.input_size_per_partition,
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
self.has_weight = True
|
||||
if bias:
|
||||
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
|
||||
self.bias = Parameter(
|
||||
torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
|
||||
)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
|
@ -1052,7 +1067,7 @@ class VocabParallelClassifier2D(ParallelLayer):
|
|||
self.output_size_per_partition = divide(num_classes, self.summa_dim)
|
||||
|
||||
# create weight, shape: [k/q, h/q]
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
self.has_weight = False
|
||||
|
|
|
@ -5,10 +5,10 @@ import torch.distributed as dist
|
|||
from torch import Tensor
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
def get_parallel_group(parallel_mode: ParallelMode):
|
||||
|
@ -205,7 +205,7 @@ class Matmul_AB_2p5D(torch.autograd.Function):
|
|||
B_shape = B.shape
|
||||
B = B.reshape((-1, B_shape[-1]))
|
||||
C_shape = (A.shape[0], B.shape[-1])
|
||||
C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
|
||||
|
||||
# use circular buffer to store the communication tensor
|
||||
# 2 is enough for all cases
|
||||
|
@ -362,7 +362,7 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
|
|||
B_shape = B.shape
|
||||
B = B.reshape((-1, B_shape[-1]))
|
||||
C_shape = (A.shape[0], B.shape[0])
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
|
||||
|
||||
# use circular buffer to store the communication tensor
|
||||
# 2 is enough for all cases
|
||||
|
@ -527,7 +527,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
|
|||
B_shape = B.shape
|
||||
B = B.reshape((-1, B_shape[-1]))
|
||||
C_shape = (A.shape[-1], B.shape[-1])
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
|
||||
C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
|
||||
|
||||
# use circular buffer to store the communication tensor
|
||||
# 2 is enough for all cases
|
||||
|
@ -661,7 +661,9 @@ class _Add_Bias_2p5D(torch.autograd.Function):
|
|||
if row_rank == 0:
|
||||
bias_temp = bias.clone()
|
||||
else:
|
||||
bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
|
||||
bias_temp = torch.zeros(
|
||||
output_size_per_partition, dtype=bias.dtype, device=get_accelerator().get_current_device()
|
||||
)
|
||||
src_rank = (
|
||||
col_rank
|
||||
+ dep_rank * tesseract_dim**2
|
||||
|
@ -984,7 +986,7 @@ class SplitFirst(torch.autograd.Function):
|
|||
@custom_bwd
|
||||
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
|
||||
grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
|
||||
grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device())
|
||||
grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_accelerator().get_current_device())
|
||||
dist.all_gather(
|
||||
list(grad.chunk(ctx.tesseract_dim, dim=0)), output_grad.contiguous(), group=gpc.get_group(ctx.para_mode)
|
||||
)
|
||||
|
|
|
@ -8,6 +8,7 @@ import torch.nn.functional as F
|
|||
from torch import Tensor
|
||||
from torch.nn import Parameter
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.communication import broadcast
|
||||
from colossalai.legacy.context import ParallelMode, seed
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
|
@ -19,7 +20,6 @@ from colossalai.legacy.utils.checkpointing import (
|
|||
partition_tensor_parallel_state_dict,
|
||||
)
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
from ..base_layer import ParallelLayer
|
||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||
|
@ -84,7 +84,7 @@ class Linear2p5D(ParallelLayer):
|
|||
self.hidden_size_per_partition = divide(out_features, self.tesseract_dim)
|
||||
|
||||
# create weight, shape: [k/q, h/q]
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)
|
||||
)
|
||||
|
@ -272,7 +272,7 @@ class LayerNorm2p5D(ParallelLayer):
|
|||
self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # *
|
||||
|
||||
# create parameters
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
|
||||
self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
|
||||
if bias:
|
||||
|
@ -451,18 +451,24 @@ class PatchEmbedding2p5D(ParallelLayer):
|
|||
self.weight = Parameter(
|
||||
torch.empty(
|
||||
(self.embed_size_per_partition, in_chans, *self.patch_size),
|
||||
device=get_current_device(),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype))
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
|
||||
)
|
||||
|
||||
self.cls_token = Parameter(
|
||||
torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)
|
||||
torch.zeros(
|
||||
(1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
self.pos_embed = Parameter(
|
||||
torch.zeros(
|
||||
(1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype
|
||||
(1, self.num_patches + 1, self.embed_size_per_partition),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -632,7 +638,9 @@ class Embedding2p5D(ParallelLayer):
|
|||
self.embed_kwargs = kwargs
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
|
||||
torch.empty(
|
||||
(num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
|
||||
self.reset_parameters(weight_initializer)
|
||||
|
@ -772,7 +780,7 @@ class VocabParallelEmbedding2p5D(ParallelLayer):
|
|||
self.weight = Parameter(
|
||||
torch.empty(
|
||||
(self.num_embeddings_per_partition, self.embed_dim_per_partition),
|
||||
device=get_current_device(),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
@ -910,11 +918,18 @@ class Classifier2p5D(ParallelLayer):
|
|||
self.has_weight = False
|
||||
else:
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)
|
||||
torch.empty(
|
||||
self.num_classes,
|
||||
self.input_size_per_partition,
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
self.has_weight = True
|
||||
if bias:
|
||||
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
|
||||
self.bias = Parameter(
|
||||
torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
|
||||
)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
|
@ -1068,7 +1083,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
|
|||
self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim)
|
||||
|
||||
# create weight, shape: [k/q, h/q]
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
if weight is not None:
|
||||
self.weight = weight
|
||||
self.has_weight = False
|
||||
|
|
|
@ -8,6 +8,7 @@ import torch.nn.functional as F
|
|||
from torch import Tensor
|
||||
from torch.nn import Parameter
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.communication import all_reduce, broadcast
|
||||
from colossalai.legacy.constants import (
|
||||
INPUT_GROUP_3D,
|
||||
|
@ -27,7 +28,6 @@ from colossalai.legacy.utils.checkpointing import (
|
|||
partition_tensor_parallel_state_dict,
|
||||
)
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
|
||||
from ._operation import (
|
||||
|
@ -69,11 +69,13 @@ class LayerNorm3D(ParallelLayer):
|
|||
self.normalized_shape_per_partition = divide(normalized_shape, self.depth)
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
|
||||
torch.ones(self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
|
||||
)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
|
||||
torch.zeros(
|
||||
self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.bias = None
|
||||
|
@ -202,13 +204,15 @@ class Linear3D(ParallelLayer):
|
|||
torch.empty(
|
||||
self.in_features_per_partition,
|
||||
self.out_features_per_partition,
|
||||
device=get_current_device(),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
|
||||
torch.zeros(
|
||||
self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.bias = None
|
||||
|
@ -380,11 +384,18 @@ class Classifier3D(ParallelLayer):
|
|||
self.has_weight = False
|
||||
else:
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype)
|
||||
torch.empty(
|
||||
self.num_classes,
|
||||
self.in_features_per_partition,
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
self.has_weight = True
|
||||
if bias:
|
||||
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
|
||||
self.bias = Parameter(
|
||||
torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
|
||||
)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
|
@ -523,14 +534,16 @@ class VocabParallelClassifier3D(ParallelLayer):
|
|||
torch.empty(
|
||||
self.out_features_per_partition,
|
||||
self.in_features_per_partition,
|
||||
device=get_current_device(),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
self.has_weight = True
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
|
||||
torch.zeros(
|
||||
self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.bias = None
|
||||
|
@ -705,16 +718,24 @@ class PatchEmbedding3D(ParallelLayer):
|
|||
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty(
|
||||
(embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype
|
||||
(embed_size_per_partition, in_chans, *self.patch_size),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype))
|
||||
self.bias = nn.Parameter(
|
||||
torch.empty(embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
|
||||
)
|
||||
|
||||
self.cls_token = nn.Parameter(
|
||||
torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
|
||||
torch.zeros((1, 1, embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype)
|
||||
)
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
|
||||
torch.zeros(
|
||||
(1, self.num_patches + 1, embed_size_per_partition),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
||||
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
|
||||
|
@ -880,7 +901,9 @@ class Embedding3D(ParallelLayer):
|
|||
self.embed_kwargs = kwargs
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
|
||||
torch.empty(
|
||||
(num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
|
||||
self.reset_parameters(weight_initializer)
|
||||
|
@ -1019,7 +1042,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
|
|||
self.weight = Parameter(
|
||||
torch.empty(
|
||||
(self.num_embeddings_per_partition, self.embed_dim_per_partition),
|
||||
device=get_current_device(),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
|
|
|
@ -5,11 +5,11 @@ import torch
|
|||
from torch import distributed as dist
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.communication import ring_forward
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.legacy.nn.layer.parallel_sequence._utils import _calc_current_device_range, _calc_incoming_device_range
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class RingQK(torch.autograd.Function):
|
||||
|
@ -30,7 +30,7 @@ class RingQK(torch.autograd.Function):
|
|||
sub_seq_length,
|
||||
sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
|
||||
dtype=sub_q.dtype,
|
||||
device=get_current_device(),
|
||||
device=get_accelerator().get_current_device(),
|
||||
)
|
||||
|
||||
# compute local QK^T
|
||||
|
@ -71,7 +71,7 @@ class RingQK(torch.autograd.Function):
|
|||
grad_q = torch.zeros_like(
|
||||
sub_q,
|
||||
dtype=sub_q.dtype,
|
||||
device=get_current_device(),
|
||||
device=get_accelerator().get_current_device(),
|
||||
)
|
||||
|
||||
# compute with local sub_k
|
||||
|
@ -105,7 +105,7 @@ class RingAV(torch.autograd.Function):
|
|||
batch_size * num_attention_heads,
|
||||
sub_seq_length,
|
||||
attention_head_size,
|
||||
device=get_current_device(),
|
||||
device=get_accelerator().get_current_device(),
|
||||
dtype=attention_score.dtype,
|
||||
)
|
||||
|
||||
|
@ -142,7 +142,9 @@ class RingAV(torch.autograd.Function):
|
|||
grad_v /= local_world_size
|
||||
|
||||
# calculate gradient for attention score
|
||||
grad_attention_score = torch.zeros_like(attention_scores, dtype=grad_output.dtype, device=get_current_device())
|
||||
grad_attention_score = torch.zeros_like(
|
||||
attention_scores, dtype=grad_output.dtype, device=get_accelerator().get_current_device()
|
||||
)
|
||||
|
||||
# compute with local sub_k
|
||||
grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1))
|
||||
|
|
|
@ -7,10 +7,10 @@ from torch import Tensor
|
|||
from torch import nn as nn
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.context import seed
|
||||
from colossalai.legacy.registry import LAYERS
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
from ..utils import to_2tuple
|
||||
|
||||
|
@ -173,12 +173,18 @@ class VanillaPatchEmbedding(nn.Module):
|
|||
self.flatten = flatten
|
||||
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype)
|
||||
torch.empty(
|
||||
(embed_size, in_chans, *self.patch_size), device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
self.bias = nn.Parameter(torch.empty(embed_size, device=get_accelerator().get_current_device(), dtype=dtype))
|
||||
self.cls_token = nn.Parameter(
|
||||
torch.zeros((1, 1, embed_size), device=get_accelerator().get_current_device(), dtype=dtype)
|
||||
)
|
||||
self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype))
|
||||
self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype))
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype)
|
||||
torch.zeros(
|
||||
(1, self.num_patches + 1, embed_size), device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
|
||||
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
|
||||
|
@ -242,11 +248,15 @@ class VanillaClassifier(nn.Module):
|
|||
self.has_weight = False
|
||||
else:
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype)
|
||||
torch.empty(
|
||||
self.num_classes, self.in_features, device=get_accelerator().get_current_device(), dtype=dtype
|
||||
)
|
||||
)
|
||||
self.has_weight = True
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
|
||||
self.bias = nn.Parameter(
|
||||
torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
|
||||
)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
|
@ -287,7 +297,7 @@ class VanillaLayerNorm(nn.Module):
|
|||
self.normalized_shape = (normalized_shape,)
|
||||
self.variance_epsilon = eps
|
||||
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
|
||||
self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs))
|
||||
if bias:
|
||||
|
@ -333,7 +343,7 @@ class VanillaLinear(nn.Module):
|
|||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.skip_bias_add = skip_bias_add
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
|
||||
|
|
|
@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss

from colossalai.accelerator import get_accelerator
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
from colossalai.legacy.nn.layer.parallel_2d._utils import assert_summa_initialization
from colossalai.legacy.registry import LOSSES
from colossalai.utils import get_current_device


@LOSSES.register_module

@ -118,7 +118,7 @@ class _VocabParallelCrossEntropy2D(torch.autograd.Function):
grad_2d = grad_input.view(-1, partition_vocab_size)

# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device())
grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()

# Finally elementwise multiplication with the output gradients.
|
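In the backward pass the row indices have to live on the same device as the gradient buffer, which is why the arange now comes from the accelerator. A sketch of that step with illustrative shapes (softmax_out, masked_target and target_mask follow the names used in the hunk):

    import torch
    from colossalai.accelerator import get_accelerator

    def subtract_target_grad(softmax_out: torch.Tensor, masked_target: torch.Tensor, target_mask: torch.Tensor):
        grad_2d = softmax_out.view(-1, softmax_out.size(-1)).clone()
        rows = torch.arange(grad_2d.size(0), device=get_accelerator().get_current_device())
        # only rows whose label falls inside this rank's vocabulary shard get the -1 correction
        grad_2d[rows, masked_target.view(-1)] -= 1.0 - target_mask.view(-1).float()
        return grad_2d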
|
@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
|
|||
from torch.nn.functional import cross_entropy
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.context import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.legacy.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
|
||||
from colossalai.legacy.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
|
||||
from colossalai.legacy.registry import LOSSES
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
@LOSSES.register_module
|
||||
|
@ -112,7 +112,7 @@ class _VocabParallelCrossEntropy2p5D(torch.autograd.Function):
|
|||
grad_2d = grad_input.view(-1, partition_vocab_size)
|
||||
|
||||
# Add the gradient from matching classes.
|
||||
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
|
||||
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device())
|
||||
grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
|
||||
|
||||
# Finally elementwise multiplication with the output gradients.
|
||||
|
|
|
@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
|
|||
from torch.nn.functional import cross_entropy
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
|
||||
from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
|
||||
from colossalai.legacy.registry import LOSSES
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
@LOSSES.register_module
|
||||
|
@ -80,7 +80,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function):
|
|||
target_mask = (targets < vocab_start) | (targets > vocab_end)
|
||||
masked_target = targets.clone() - vocab_start
|
||||
masked_target[target_mask] = 0
|
||||
arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device())
|
||||
arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_accelerator().get_current_device())
|
||||
predicted_logits = logits[arange_1d, masked_target]
|
||||
predicted_logits = predicted_logits.clone().contiguous().view_as(targets)
|
||||
predicted_logits[target_mask] = 0.0
|
||||
|
@ -110,7 +110,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function):
|
|||
grad_2d = input_grad.view(-1, partition_vocab_size)
|
||||
|
||||
# Add the gradient from matching classes.
|
||||
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
|
||||
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device())
|
||||
grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
|
||||
input_grad.mul_(output_grad.unsqueeze(dim=-1))
|
||||
|
||||
|
|
|
@ -7,12 +7,12 @@ from typing import Callable
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.communication import all_reduce
|
||||
from colossalai.legacy.context import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.legacy.registry import HOOKS
|
||||
from colossalai.legacy.utils import is_no_pp_or_last_stage
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from ._base_hook import BaseHook
|
||||
from ._commons_ import _format_number
|
||||
|
@ -82,8 +82,8 @@ class LossMetric(Metric):
|
|||
|
||||
def __init__(self, epoch_only):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.last_step_loss = torch.zeros(1, device=get_current_device())
|
||||
self.accum_loss = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_loss = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
self.accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
self.count = 0
|
||||
|
||||
def reset(self) -> None:
|
||||
|
@ -164,10 +164,10 @@ class AccuracyMetric(Metric):
|
|||
def __init__(self, epoch_only: bool, accuracy_func: Callable):
|
||||
super().__init__(epoch_only=epoch_only)
|
||||
self.acc = accuracy_func
|
||||
self.last_step_sum = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_correct = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_sum = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_correct = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_sum = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
self.last_step_correct = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
self.accumulated_sum = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
self.accumulated_correct = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
|
||||
def reset(self) -> None:
|
||||
self.last_step_sum.zero_()
|
||||
|
@ -320,10 +320,10 @@ class ThroughputMetric(Metric):
|
|||
super().__init__(epoch_only=epoch_only)
|
||||
self.ignored_steps = ignored_steps
|
||||
self.cur_steps = 0
|
||||
self.accumulated_num_samples = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_used_time = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_num_samples = torch.zeros(1, device=get_current_device())
|
||||
self.last_step_used_time = torch.zeros(1, device=get_current_device())
|
||||
self.accumulated_num_samples = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
self.accumulated_used_time = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
self.last_step_num_samples = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
self.last_step_used_time = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
self._tflop_per_step = tflop_per_step
|
||||
self._use_local = use_local
|
||||
|
||||
|
|
|
@@ -6,8 +6,8 @@ import weakref
 import torch
 from torch.utils.checkpoint import check_backward_validity, detach_variable

+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states
-from colossalai.utils.device import autocast, get_current_device


 def copy_to_device(obj, device):

@@ -33,7 +33,7 @@ class CheckpointFunction(torch.autograd.Function):
         check_backward_validity(args)
         ctx.run_function = run_function
         ctx.activation_offload = activation_offload
-        ctx.device = get_current_device()
+        ctx.device = get_accelerator().get_current_device()

         # preserve rng states
         ctx.fwd_cpu_rng_state = torch.get_rng_state()

@@ -110,7 +110,7 @@ class CheckpointFunction(torch.autograd.Function):
                 inputs[idx] = tensors[i]
         detached_inputs = detach_variable(tuple(inputs))
         if ctx.had_autocast_in_fwd:
-            with torch.enable_grad(), autocast():
+            with torch.enable_grad(), get_accelerator().autocast()():
                 outputs = ctx.run_function(*detached_inputs)
         else:
             with torch.enable_grad():

@@ -226,7 +226,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):

         # rerun forward, the inner_pack will store all the activations in storage
         if has_autocast_in_fwd:
-            with torch.enable_grad(), autocast(), torch.autograd.graph.saved_tensors_hooks(
+            with torch.enable_grad(), get_accelerator().autocast()(), torch.autograd.graph.saved_tensors_hooks(
                 inner_pack, inner_unpack
             ):
                 _unused = function(*args)

@@ -245,7 +245,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):

     # get device if we need to offload the activation
     if activation_offload:
-        device = get_current_device()
+        device = get_accelerator().get_current_device()

     # run function with pack and unpack as saved_tensors_hooks
     with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
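For context on the checkpoint hunks: during backward the saved inputs are detached and the forward is re-run, and if the original forward ran under autocast the re-run must be wrapped the same way. A rough, standalone sketch of that control flow in plain PyTorch (deliberately not the accelerator wrapper used in the patch):

```python
import torch


def rerun_forward(run_function, detached_inputs, had_autocast_in_fwd: bool):
    # Re-execute the checkpointed forward with gradients enabled so that
    # activations are rebuilt; mirror the original autocast context if needed.
    if had_autocast_in_fwd and torch.cuda.is_available():
        with torch.enable_grad(), torch.cuda.amp.autocast():
            return run_function(*detached_inputs)
    with torch.enable_grad():
        return run_function(*detached_inputs)
```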
@@ -6,9 +6,9 @@ import torch
 import torch.distributed as dist
 from packaging import version

+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.core import global_context as gpc
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device

 _GLOBAL_CUDA_MEM_FRACTION = 1.0
 _GLOBAL_CPU_MEM_CAPACITY = -1

@@ -112,7 +112,10 @@ def colo_device_memory_capacity(device: torch.device) -> int:
         # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory.
         return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node
     if device.type == "cuda":
-        return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
+        return (
+            torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory
+            * _GLOBAL_CUDA_MEM_FRACTION
+        )


 def colo_device_memory_used(device: torch.device) -> int:

@@ -153,7 +156,7 @@ def colo_set_process_memory_fraction(ratio: float) -> None:
         return
     global _GLOBAL_CUDA_MEM_FRACTION
     _GLOBAL_CUDA_MEM_FRACTION = ratio
-    torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device())
+    torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_accelerator().get_current_device())


 def colo_set_cpu_memory_capacity(size: int) -> None:
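The capacity helper above is just `total_memory * fraction`, with the same fraction handed to the CUDA allocator. A hedged, plain-`torch.cuda` illustration of that arithmetic (function and argument names are illustrative):

```python
import torch


def capped_cuda_capacity(fraction: float = 1.0, device_index: int = 0) -> int:
    """Usable bytes on a CUDA device after applying a per-process memory fraction."""
    if not torch.cuda.is_available():
        return 0
    total = torch.cuda.get_device_properties(device_index).total_memory
    return int(total * fraction)


# The allocator can be told to honour the same cap, e.g.:
# torch.cuda.set_per_process_memory_fraction(0.8, 0)
```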
@@ -8,7 +8,7 @@ import torch.distributed as dist
 from torch.autograd.profiler import profile
 from torch.distributed import ReduceOp

-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator

 from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time


@@ -177,7 +177,7 @@ class CommProfiler(BaseProfiler):

         assert current_comm_event is not None, "dist op has not been found"

-        buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device())
+        buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_accelerator().get_current_device())
         torch_all_reduce(buffer, op=ReduceOp.MIN, group=group)
         current_comm_event.self_cuda_time = buffer.item()
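The profiler change keeps the existing trick of agreeing on one timing value across ranks by all-reducing it with `MIN`; only the staging device changes. A self-contained sketch of that synchronization (names illustrative):

```python
import torch
import torch.distributed as dist


def sync_min_time(self_time_us: float, group=None) -> float:
    """All-reduce a per-rank timing so every rank reports the same (minimum) value."""
    if not dist.is_initialized():
        return self_time_us
    buf = torch.tensor([self_time_us], dtype=torch.float64)
    dist.all_reduce(buf, op=dist.ReduceOp.MIN, group=group)
    return buf.item()
```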
@ -3,7 +3,7 @@ import types
|
|||
from time import time
|
||||
from typing import List
|
||||
|
||||
from colossalai.utils.device import get_current_device
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
from .stateful_tensor import StatefulTensor, TensorState
|
||||
from .tensor_placement_policy import TensorPlacementPolicy
|
||||
|
@ -69,7 +69,7 @@ class StatefulTensorMgr(object):
|
|||
# move COMPUTE tensors to CUDA
|
||||
self._cpu_gpu_move_volume += cuda_demand
|
||||
for t in move_to_cuda_tensor_list:
|
||||
colo_model_data_tensor_move_inline(t, get_current_device())
|
||||
colo_model_data_tensor_move_inline(t, get_accelerator().get_current_device())
|
||||
|
||||
@property
|
||||
def cpu_gpu_move_volume(self):
|
||||
|
|
|
@ -5,8 +5,8 @@ from typing import List, Optional, Type
|
|||
|
||||
import torch
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
|
||||
|
||||
from .stateful_tensor import StatefulTensor
|
||||
|
@ -38,7 +38,7 @@ class CPUTensorPlacementPolicy(TensorPlacementPolicy):
|
|||
class CUDATensorPlacementPolicy(TensorPlacementPolicy):
|
||||
def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
|
||||
assert torch.cuda.is_available(), "Cannot use CUDATensorPlacementPolicy when CUDA is not available"
|
||||
super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector)
|
||||
super().__init__(get_accelerator().get_current_device(), mem_stats_collector=mem_stats_collector)
|
||||
|
||||
def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
|
||||
return 0, 0
|
||||
|
@ -78,7 +78,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
|
|||
int: the volume of memory that is evicted
|
||||
"""
|
||||
start = time()
|
||||
cuda_capacity = colo_device_memory_capacity(get_current_device())
|
||||
cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
|
||||
used_cuda_model_data = StatefulTensor.GST_MGR.total_mem["cuda"]
|
||||
if warmup:
|
||||
# We designate a part of CUDA memory for model data in warmup iterations.
|
||||
|
|
|
@ -4,8 +4,8 @@ import torch
|
|||
import torch.distributed as dist
|
||||
from torch._utils import _flatten_dense_tensors as flatten
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .tensor_shard_strategy import TensorShardStrategy
|
||||
|
||||
|
@ -30,9 +30,11 @@ class BucketTensorShardStrategy(TensorShardStrategy):
|
|||
rank = dist.get_rank(process_group)
|
||||
for i in range(world_size):
|
||||
if i == rank:
|
||||
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
|
||||
buffer_list.append(
|
||||
flatten([t.payload for t in tensor_list]).cuda(get_accelerator().get_current_device())
|
||||
)
|
||||
else:
|
||||
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
|
||||
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_accelerator().get_current_device()))
|
||||
dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
|
||||
# Move to target device before splitting buffer
|
||||
# Ensure we utilize maximum PCIE bandwidth
|
||||
|
|
|
@ -3,11 +3,11 @@ from typing import List, Optional
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline
|
||||
from colossalai.legacy.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.legacy.zero.shard_utils.commons import get_shard
|
||||
from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class TensorShardStrategy(BaseShardStrategy):
|
||||
|
@ -34,9 +34,9 @@ class TensorShardStrategy(BaseShardStrategy):
|
|||
if t.is_sharded:
|
||||
return
|
||||
if t.payload.device.type == "cuda":
|
||||
assert t.payload.device == get_current_device(), (
|
||||
assert t.payload.device == get_accelerator().get_current_device(), (
|
||||
f"shard tensor on cuda device index {t.payload.device.index},"
|
||||
f" but current cuda device is {get_current_device()}"
|
||||
f" but current cuda device is {get_accelerator().get_current_device()}"
|
||||
)
|
||||
sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
|
||||
t.payload_reset(sharded_payload)
|
||||
|
@ -50,7 +50,9 @@ class TensorShardStrategy(BaseShardStrategy):
|
|||
world_size = dist.get_world_size(process_group)
|
||||
rank = dist.get_rank(process_group)
|
||||
|
||||
buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device())
|
||||
buffer = torch.empty(
|
||||
payload_numel * world_size, dtype=t.payload.dtype, device=get_accelerator().get_current_device()
|
||||
)
|
||||
buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
|
||||
buffer_list[rank].copy_(t.payload)
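The gather path in this shard strategy allocates one flat buffer of `payload_numel * world_size` elements on the current device, copies the local shard into its slot, and all-gathers in place. A standalone sketch of the same pattern with plain `torch.distributed` (tensor and function names are illustrative):

```python
import torch
import torch.distributed as dist


def gather_shards(local_shard: torch.Tensor, group=None) -> torch.Tensor:
    """All-gather equally sized shards into one flat tensor on the shard's device."""
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    buffer = torch.empty(
        local_shard.numel() * world_size, dtype=local_shard.dtype, device=local_shard.device
    )
    buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
    buffer_list[rank].copy_(local_shard.flatten())
    dist.all_gather(buffer_list, buffer_list[rank], group=group)
    return buffer
```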
@ -10,6 +10,7 @@ import torch.nn as nn
|
|||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.context.parallel_mode import ParallelMode
|
||||
from colossalai.legacy.core import global_context as gpc
|
||||
from colossalai.legacy.utils.memory import colo_device_memory_capacity
|
||||
|
@ -22,7 +23,7 @@ from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_move_to_c
|
|||
from colossalai.legacy.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.legacy.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import disposable, get_current_device
|
||||
from colossalai.utils import disposable
|
||||
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
|
||||
|
||||
from ._utils import (
|
||||
|
@ -212,8 +213,12 @@ class ShardedModelV2(nn.Module):
|
|||
self.logger.error(f"dump memory tracer collected information to a {filename}", ranks=[0])
|
||||
if gpc.get_global_rank() == 0:
|
||||
with open(filename, "w+") as f:
|
||||
f.write(f"cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n")
|
||||
f.write(f"cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n")
|
||||
f.write(
|
||||
f"cuda reserved {torch.cuda.memory_reserved(get_accelerator().get_current_device()) / 1e9} GB\n"
|
||||
)
|
||||
f.write(
|
||||
f"cuda max allocated {torch.cuda.max_memory_allocated(get_accelerator().get_current_device()) / 1e9} GB\n"
|
||||
)
|
||||
f.write("CUDA model data (GB)\n")
|
||||
f.write("\n")
|
||||
f.write("CUDA non model data (GB)\n")
|
||||
|
@ -266,7 +271,8 @@ class ShardedModelV2(nn.Module):
|
|||
# model data is fixed in cuda during training.
|
||||
# cuda margin space can be used to store OS.
|
||||
self._cuda_margin_space = (
|
||||
colo_device_memory_capacity(get_current_device()) - self._memstats_collector._memstats.max_overall_cuda
|
||||
colo_device_memory_capacity(get_accelerator().get_current_device())
|
||||
- self._memstats_collector._memstats.max_overall_cuda
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
|
@ -3,13 +3,13 @@ from typing import Optional
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.registry import OPHOOKS
|
||||
from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
|
||||
from colossalai.legacy.zero.gemini.stateful_tensor import TensorState
|
||||
from colossalai.legacy.zero.gemini.stateful_tensor_mgr import StatefulTensorMgr
|
||||
from colossalai.legacy.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
|
||||
|
||||
|
||||
|
@ -33,7 +33,7 @@ class ZeroHook(BaseOpHook):
|
|||
self.process_group = process_group
|
||||
|
||||
# NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
|
||||
self.computing_device = get_current_device()
|
||||
self.computing_device = get_accelerator().get_current_device()
|
||||
|
||||
self._memstarts_collector = memstarts_collector
|
||||
self._stateful_tensor_mgr = stateful_tensor_mgr
|
||||
|
|
|
@ -8,9 +8,9 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.moe._operation import moe_cumsum
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class MoeRouter(nn.Module, ABC):
|
||||
|
@ -24,14 +24,16 @@ class MoeRouter(nn.Module, ABC):
|
|||
drop_tks (bool, optional): Whether drops tokens in evaluation
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
k_value: int,
|
||||
capacity_factor_train: float,
|
||||
capacity_factor_eval: float,
|
||||
min_capacity: int,
|
||||
noisy_func: Optional[Callable] = None,
|
||||
drop_tks: bool = True,
|
||||
use_kernel: bool = False):
|
||||
use_kernel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.k_value = k_value
|
||||
self.capacity_factor_train = capacity_factor_train
|
||||
|
@ -68,8 +70,9 @@ class MoeRouter(nn.Module, ABC):
|
|||
if router_probs.dim() == expert_indices.dim() == 2:
|
||||
router_probs = router_probs.unsqueeze(0)
|
||||
expert_indices = expert_indices.unsqueeze(0)
|
||||
assert router_probs.dim() == expert_indices.dim() == 3, \
|
||||
"router_probs must be 3D tensor and expert_indices must be 4D tensor"
|
||||
assert (
|
||||
router_probs.dim() == expert_indices.dim() == 3
|
||||
), "router_probs must be 3D tensor and expert_indices must be 4D tensor"
|
||||
|
||||
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
|
||||
expert_mask = F.one_hot(expert_indices, num_experts)
|
||||
|
@ -122,25 +125,29 @@ class Top1Router(MoeRouter):
|
|||
drop_tks (bool, optional): Whether drops tokens in evaluation
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
select_policy: str = "first",
|
||||
noisy_func: Optional[Callable] = None,
|
||||
drop_tks: bool = True):
|
||||
super().__init__(k_value=1,
|
||||
drop_tks: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
k_value=1,
|
||||
capacity_factor_train=capacity_factor_train,
|
||||
capacity_factor_eval=capacity_factor_eval,
|
||||
min_capacity=min_capacity,
|
||||
noisy_func=noisy_func,
|
||||
drop_tks=drop_tks)
|
||||
drop_tks=drop_tks,
|
||||
)
|
||||
self.select_policy = select_policy
|
||||
assert select_policy in {"first", "random"}
|
||||
if select_policy == "random":
|
||||
self.uniform = torch.distributions.uniform.Uniform(
|
||||
low=torch.tensor(0.0, device=get_current_device()),
|
||||
high=torch.tensor(1.0, device=get_current_device())
|
||||
low=torch.tensor(0.0, device=get_accelerator().get_current_device()),
|
||||
high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
|
||||
).rsample
|
||||
|
||||
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
|
||||
|
@ -216,18 +223,22 @@ class Top2Router(MoeRouter):
|
|||
drop_tks (bool, optional): Whether drops tokens in evaluation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
noisy_func: Optional[Callable] = None,
|
||||
drop_tks: bool = True):
|
||||
super().__init__(k_value=2,
|
||||
drop_tks: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
k_value=2,
|
||||
capacity_factor_train=capacity_factor_train,
|
||||
capacity_factor_eval=capacity_factor_eval,
|
||||
min_capacity=min_capacity,
|
||||
noisy_func=noisy_func,
|
||||
drop_tks=drop_tks)
|
||||
drop_tks=drop_tks,
|
||||
)
|
||||
|
||||
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
|
||||
"""
|
||||
|
@ -255,7 +266,7 @@ class Top2Router(MoeRouter):
|
|||
top2_idx = torch.argmax(logits_except1, dim=-1)
|
||||
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
|
||||
|
||||
cmask = (mask1 + mask2) # loss: [s, e]
|
||||
cmask = mask1 + mask2 # loss: [s, e]
|
||||
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
|
||||
|
||||
# calculate loss
|
||||
|
@ -336,15 +347,18 @@ class TopKRouter(MoeRouter):
|
|||
oversubscribed / reach capacity.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
num_selected_experts: int,
|
||||
capacity_factor_train: float = 1.25,
|
||||
capacity_factor_eval: float = 2.0,
|
||||
min_capacity: int = 4,
|
||||
noisy_func: Optional[Callable] = None,
|
||||
drop_tks: bool = True):
|
||||
super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func,
|
||||
drop_tks)
|
||||
drop_tks: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -410,7 +424,7 @@ class TopKRouter(MoeRouter):
|
|||
# The combine array will be used for combining expert outputs, scaled by the
|
||||
# router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
|
||||
# expert_capacity].
|
||||
combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask)
|
||||
combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
|
||||
|
||||
return combine_array, dispatch_mask
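Most of the router edits above are formatting, but the dispatch logic they wrap is easy to lose in the diff. A compact, hedged sketch of capacity-limited top-1 gating in plain PyTorch, separate from the `Top1Router`/`Top2Router` classes themselves:

```python
import torch
import torch.nn.functional as F


def top1_dispatch(logits: torch.Tensor, capacity: int):
    """Route each token to its top expert and drop tokens beyond expert capacity."""
    probs = F.softmax(logits, dim=-1)                        # [tokens, experts]
    top1_idx = probs.argmax(dim=-1)                          # [tokens]
    mask = F.one_hot(top1_idx, num_classes=logits.size(-1))  # [tokens, experts]
    position_in_expert = torch.cumsum(mask, dim=0) * mask    # 1-based slot per expert
    keep = (position_in_expert <= capacity).to(probs.dtype)
    dispatch_mask = mask * keep                              # tokens actually routed
    gate = (probs * dispatch_mask).sum(dim=-1)               # gate weight for kept tokens
    return dispatch_mask, gate
```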
@ -7,13 +7,12 @@ import torch.distributed as dist
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.moe.manager import MOE_MANAGER
|
||||
from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class ForceFP32Parameter(torch.nn.Parameter):
|
||||
|
||||
def half(self, memory_format=None):
|
||||
return self.data.clone()
|
||||
|
||||
|
@ -30,8 +29,8 @@ class NormalNoiseGenerator:
|
|||
|
||||
def __init__(self, num_experts: int):
|
||||
self.normal = torch.distributions.normal.Normal(
|
||||
loc=torch.tensor(0.0, device=get_current_device()),
|
||||
scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()),
|
||||
loc=torch.tensor(0.0, device=get_accelerator().get_current_device()),
|
||||
scale=torch.tensor(1.0 / num_experts**2, device=get_accelerator().get_current_device()),
|
||||
).rsample
|
||||
|
||||
def __call__(self, inputs: torch.Tensor):
|
||||
|
@ -52,8 +51,8 @@ class UniformNoiseGenerator:
|
|||
|
||||
def __init__(self, eps: float = 1e-2):
|
||||
self.uniform = torch.distributions.uniform.Uniform(
|
||||
low=torch.tensor(1.0 - eps, device=get_current_device()),
|
||||
high=torch.tensor(1.0 + eps, device=get_current_device()),
|
||||
low=torch.tensor(1.0 - eps, device=get_accelerator().get_current_device()),
|
||||
high=torch.tensor(1.0 + eps, device=get_accelerator().get_current_device()),
|
||||
).rsample
|
||||
|
||||
def __call__(self, inputs: torch.Tensor):
|
||||
|
@ -193,18 +192,13 @@ def create_ep_hierarchical_group(
|
|||
assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
|
||||
nproc_per_node = int(nproc_per_node)
|
||||
else:
|
||||
assert dist.get_world_size() % nproc_per_node == 0, \
|
||||
"nproc_per_node should be a divisor of world_size."
|
||||
assert dist.get_world_size() % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size."
|
||||
num_node = dist.get_world_size() // nproc_per_node
|
||||
|
||||
intra_src_rank = None
|
||||
ep_intra_node_group = None
|
||||
for i in range(num_node):
|
||||
ep_intra_ranks = [
|
||||
i * nproc_per_node + j
|
||||
for j in range(nproc_per_node)
|
||||
if j in ep_group_ranks
|
||||
]
|
||||
ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
|
||||
group = dist.new_group(ep_intra_ranks)
|
||||
if rank in ep_intra_ranks:
|
||||
assert ep_intra_node_group is None
|
||||
|
@ -212,10 +206,7 @@ def create_ep_hierarchical_group(
|
|||
intra_src_rank = ep_intra_ranks[0]
|
||||
|
||||
ep_inter_node_group = None
|
||||
ep_inter_ranks = [
|
||||
ep_group_ranks[0] + i * nproc_per_node
|
||||
for i in range(num_node)
|
||||
]
|
||||
ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
|
||||
if len(ep_inter_ranks) > 1:
|
||||
group = dist.new_group(ep_inter_ranks)
|
||||
if rank in ep_inter_ranks:
|
||||
|
|
|
@ -7,10 +7,10 @@ import torch.cuda
|
|||
from torch.nn import Module
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
|
||||
from .base import PipelineSchedule
|
||||
|
@ -86,7 +86,7 @@ class GenerateSchedule(PipelineSchedule):
|
|||
"""
|
||||
micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
|
||||
self.microbatch_offset += self.microbatch_size
|
||||
return tree_map(partial(to_device, device=get_current_device()), micro_batch)
|
||||
return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
|
||||
|
||||
def _prepare_inputs_for_interval_stage(self):
|
||||
"""
@ -6,10 +6,10 @@ import torch.cuda
|
|||
from torch.nn import Module
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
|
||||
from .base import PipelineSchedule
|
||||
|
@ -56,7 +56,7 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
"""
|
||||
micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
|
||||
self.microbatch_offset[model_chunk_id] += self.microbatch_size
|
||||
return tree_map(partial(to_device, device=get_current_device()), micro_batch)
|
||||
return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
|
||||
|
||||
def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int:
|
||||
"""Helper method to get the model chunk ID given the iteration number.
|
||||
|
@ -292,7 +292,7 @@ class InterleavedSchedule(PipelineSchedule):
|
|||
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
|
||||
|
||||
if return_loss and self.stage_manager.is_last_stage():
|
||||
accum_loss = torch.zeros(1, device=get_current_device())
|
||||
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
else:
|
||||
accum_loss = None
|
||||
|
||||
|
|
|
@ -6,10 +6,10 @@ import torch.cuda
|
|||
from torch.nn import Module
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.utils.device import get_current_device
|
||||
|
||||
from ._utils import (
|
||||
detach,
|
||||
|
@ -80,7 +80,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
"""
|
||||
micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
|
||||
self.microbatch_offset += self.microbatch_size
|
||||
return tree_map(partial(to_device, device=get_current_device()), micro_batch)
|
||||
return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
|
||||
|
||||
def recv_forward(self, prev_rank: int = None) -> Any:
|
||||
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
||||
|
@ -297,7 +297,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
|
||||
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
|
||||
if return_loss and self.stage_manager.is_last_stage():
|
||||
accum_loss = torch.zeros(1, device=get_current_device())
|
||||
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
|
||||
else:
|
||||
accum_loss = None
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from torch import nn
|
|||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
from torch.distributed import ProcessGroup, get_world_size
|
||||
|
||||
from colossalai.utils.device import get_current_device, get_rng_state, manual_seed, set_rng_state
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
|
||||
class SeqParallelUtils:
|
||||
|
@ -110,10 +110,10 @@ class Randomizer:
|
|||
# 1. get the current rng state
|
||||
# 2. set the seed and store the rng state
|
||||
# 3. recover the original rng state
|
||||
device_original_rng_state = get_rng_state()
|
||||
manual_seed(seed)
|
||||
self.device_rng_state = get_rng_state()
|
||||
set_rng_state(device_original_rng_state)
|
||||
device_original_rng_state = get_accelerator().get_rng_state()
|
||||
get_accelerator().manual_seed(seed)
|
||||
self.device_rng_state = get_accelerator().get_rng_state()
|
||||
get_accelerator().set_rng_state(device_original_rng_state)
|
||||
|
||||
# to the same for cpu rng state
|
||||
cpu_original_rng_state = torch.get_rng_state()
|
||||
|
@ -122,10 +122,10 @@ class Randomizer:
|
|||
torch.set_rng_state(cpu_original_rng_state)
|
||||
|
||||
def _set_device_rng_state(self, rng_state):
|
||||
set_rng_state(rng_state)
|
||||
get_accelerator().set_rng_state(rng_state)
|
||||
|
||||
def _get_device_rng_state(self):
|
||||
current_state = get_rng_state()
|
||||
current_state = get_accelerator().get_rng_state()
|
||||
return current_state
|
||||
|
||||
def _set_cpu_rng_state(self, rng_state):
|
||||
|
@ -210,7 +210,7 @@ class Randomizer:
|
|||
index = Randomizer.index()
|
||||
if dist.is_initialized():
|
||||
# convert the index to tensor
|
||||
index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
|
||||
index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device())
|
||||
|
||||
# all gather the index
|
||||
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
|
||||
|
@ -232,7 +232,7 @@ class Randomizer:
|
|||
|
||||
if dist.is_initialized():
|
||||
# convert the index to tensor
|
||||
index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
|
||||
index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device())
|
||||
|
||||
# all gather the index
|
||||
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
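The `Randomizer` hunks swap the CUDA-specific RNG helpers for their accelerator counterparts; the seed-then-restore bookkeeping itself is unchanged. A sketch of the same pattern using only the CPU RNG, to make the sequence explicit:

```python
import torch


def capture_seeded_rng_state(seed: int) -> torch.Tensor:
    """Seed the RNG, capture the resulting state, then restore the caller's state."""
    original_state = torch.get_rng_state()
    torch.manual_seed(seed)
    seeded_state = torch.get_rng_state()   # state to switch to later
    torch.set_rng_state(original_state)    # leave the global RNG untouched
    return seeded_state
```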
@ -9,7 +9,8 @@ from typing import Any, Callable, List
|
|||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from packaging import version
|
||||
from colossalai.utils.device import empty_cache, reset_max_memory_allocated, reset_peak_memory_stats, synchronize, reset_max_memory_cached, device_count
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
|
||||
def parameterize(argument: str, values: List[Any]) -> Callable:
|
||||
|
@ -199,7 +200,7 @@ def skip_if_not_enough_gpus(min_gpus: int):
|
|||
|
||||
def _wrap_func(f):
|
||||
def _execute_by_gpu_num(*args, **kwargs):
|
||||
num_avail_gpu = device_count()
|
||||
num_avail_gpu = get_accelerator().device_count()
|
||||
if num_avail_gpu >= min_gpus:
|
||||
f(*args, **kwargs)
|
||||
|
||||
|
@ -263,11 +264,11 @@ def clear_cache_before_run():
|
|||
|
||||
def _wrap_func(f):
|
||||
def _clear_cache(*args, **kwargs):
|
||||
empty_cache()
|
||||
reset_peak_memory_stats()
|
||||
reset_max_memory_allocated()
|
||||
reset_max_memory_cached()
|
||||
synchronize()
|
||||
get_accelerator().empty_cache()
|
||||
get_accelerator().reset_peak_memory_stats()
|
||||
get_accelerator().reset_max_memory_allocated()
|
||||
get_accelerator().reset_max_memory_cached()
|
||||
get_accelerator().synchronize()
|
||||
gc.collect()
|
||||
f(*args, **kwargs)
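The updated test helper resets every accelerator-side cache and statistic before the wrapped function runs. A hedged sketch of the same decorator shape, using only accelerator calls that appear in the hunk above:

```python
import gc
from functools import wraps

from colossalai.accelerator import get_accelerator


def clear_cache_before_run():
    """Decorator factory: flush device caches and collect garbage before each call."""
    def _wrap(f):
        @wraps(f)
        def _inner(*args, **kwargs):
            acc = get_accelerator()
            acc.empty_cache()
            acc.reset_peak_memory_stats()
            acc.reset_max_memory_allocated()
            acc.reset_max_memory_cached()
            acc.synchronize()
            gc.collect()
            return f(*args, **kwargs)
        return _inner
    return _wrap
```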
@@ -7,17 +7,12 @@ from .common import (
     is_ddp_ignored,
     set_seed,
 )
-from .device import IS_NPU_AVAILABLE, empty_cache, get_current_device, set_device, set_to_cuda, synchronize
 from .multi_tensor_apply import multi_tensor_applier
 from .tensor_detector import TensorDetector
 from .timer import MultiTimer, Timer

 __all__ = [
     "conditional_context",
-    "get_current_device",
-    "synchronize",
-    "empty_cache",
-    "set_to_cuda",
     "Timer",
     "MultiTimer",
     "multi_tensor_applier",

@@ -28,6 +23,4 @@ __all__ = [
     "free_storage",
     "set_seed",
     "is_ddp_ignored",
-    "set_device",
-    "IS_NPU_AVAILABLE",
 ]
@ -1,223 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Callable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
IS_NPU_AVAILABLE: bool = False
|
||||
try:
|
||||
import torch_npu # noqa
|
||||
|
||||
IS_NPU_AVAILABLE = torch.npu.is_available()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def set_to_cuda(models):
|
||||
"""Send model to gpu.
|
||||
|
||||
:param models: nn.module or a list of module
|
||||
"""
|
||||
if isinstance(models, list) and len(models) > 1:
|
||||
ret = []
|
||||
for model in models:
|
||||
ret.append(model.to(get_current_device()))
|
||||
return ret
|
||||
elif isinstance(models, list):
|
||||
return models[0].to(get_current_device())
|
||||
else:
|
||||
return models.to(get_current_device())
|
||||
|
||||
|
||||
def get_current_device() -> torch.device:
|
||||
"""
|
||||
Returns currently selected device (gpu/cpu).
|
||||
If cuda available, return gpu, otherwise return cpu.
|
||||
"""
|
||||
if torch.cuda.is_available():
|
||||
return torch.device(f"cuda:{torch.cuda.current_device()}")
|
||||
elif IS_NPU_AVAILABLE:
|
||||
return torch.device(f"npu:{torch.npu.current_device()}")
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
def _dispatch_device_func(fn_name: str, *args, **kwargs):
|
||||
if torch.cuda.is_available():
|
||||
return getattr(torch.cuda, fn_name)(*args, **kwargs)
|
||||
elif IS_NPU_AVAILABLE:
|
||||
return getattr(torch.npu, fn_name)(*args, **kwargs)
|
||||
else:
|
||||
raise RuntimeError("No device available")
|
||||
|
||||
|
||||
# device semantics
|
||||
|
||||
|
||||
def can_device_access_peer(device, peer_device) -> bool:
|
||||
return _dispatch_device_func("can_device_access_peer", device, peer_device)
|
||||
|
||||
|
||||
def current_device() -> int:
|
||||
return _dispatch_device_func("current_device")
|
||||
|
||||
|
||||
def current_stream(device=None):
|
||||
return _dispatch_device_func("current_stream", device)
|
||||
|
||||
|
||||
def default_stream(device=None):
|
||||
return _dispatch_device_func("default_stream", device)
|
||||
|
||||
|
||||
def device_count() -> int:
|
||||
return _dispatch_device_func("device_count")
|
||||
|
||||
|
||||
def get_device_capability(device=None) -> Tuple[int, int]:
|
||||
return _dispatch_device_func("get_device_capability", device)
|
||||
|
||||
|
||||
def get_device_name(device=None) -> str:
|
||||
return _dispatch_device_func("get_device_name", device)
|
||||
|
||||
|
||||
def get_device_properties(device):
|
||||
return _dispatch_device_func("get_device_properties", device)
|
||||
|
||||
|
||||
def set_device(index: Optional[int] = None) -> None:
|
||||
if index is None:
|
||||
index = dist.get_rank() % device_count()
|
||||
_dispatch_device_func("set_device", index)
|
||||
|
||||
|
||||
def set_stream(stream_):
|
||||
return _dispatch_device_func("set_stream", stream_)
|
||||
|
||||
|
||||
def stream(stream_):
|
||||
return _dispatch_device_func("stream", stream_)
|
||||
|
||||
|
||||
def synchronize():
|
||||
return _dispatch_device_func("synchronize")
|
||||
|
||||
|
||||
def utilization(device=None) -> int:
|
||||
return _dispatch_device_func("utilization", device)
|
||||
|
||||
|
||||
# random number generator
|
||||
|
||||
|
||||
def get_rng_state(device="cuda") -> torch.Tensor:
|
||||
return _dispatch_device_func("get_rng_state", device)
|
||||
|
||||
|
||||
def get_rng_state_all() -> List[torch.Tensor]:
|
||||
return _dispatch_device_func("get_rng_state_all")
|
||||
|
||||
|
||||
def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None:
|
||||
return _dispatch_device_func("set_rng_state", new_state, device)
|
||||
|
||||
|
||||
def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None:
|
||||
return _dispatch_device_func("set_rng_state_all", new_states)
|
||||
|
||||
|
||||
def manual_seed(seed: int) -> None:
|
||||
return _dispatch_device_func("manual_seed", seed)
|
||||
|
||||
|
||||
def manual_seed_all(seed: int) -> None:
|
||||
return _dispatch_device_func("manual_seed_all", seed)
|
||||
|
||||
|
||||
def seed() -> None:
|
||||
return _dispatch_device_func("seed")
|
||||
|
||||
|
||||
def seed_all() -> None:
|
||||
return _dispatch_device_func("seed_all")
|
||||
|
||||
|
||||
def initial_seed() -> int:
|
||||
return _dispatch_device_func("initial_seed")
|
||||
|
||||
|
||||
# streams and events
|
||||
|
||||
|
||||
def Stream(device=None, priority=0, **kwargs):
|
||||
return _dispatch_device_func("Stream", device, priority, **kwargs)
|
||||
|
||||
|
||||
def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
|
||||
return _dispatch_device_func("Event", enable_timing, blocking, interprocess)
|
||||
|
||||
|
||||
# memory management
|
||||
|
||||
|
||||
def empty_cache() -> None:
|
||||
return _dispatch_device_func("empty_cache")
|
||||
|
||||
|
||||
def memory_stats(device=None) -> Dict[str, Any]:
|
||||
return _dispatch_device_func("memory_stats", device)
|
||||
|
||||
|
||||
def memory_summary(device=None, abbreviated=False) -> str:
|
||||
return _dispatch_device_func("memory_summary", device, abbreviated)
|
||||
|
||||
|
||||
def memory_snapshot():
|
||||
return _dispatch_device_func("memory_snapshot")
|
||||
|
||||
|
||||
def memory_allocated(device=None) -> int:
|
||||
return _dispatch_device_func("memory_allocated", device)
|
||||
|
||||
|
||||
def max_memory_allocated(device=None) -> int:
|
||||
return _dispatch_device_func("max_memory_allocated", device)
|
||||
|
||||
|
||||
def reset_max_memory_allocated(device=None) -> None:
|
||||
return _dispatch_device_func("reset_max_memory_allocated", device)
|
||||
|
||||
|
||||
def reset_max_memory_cached(device=None) -> None:
|
||||
return _dispatch_device_func("reset_max_memory_cached", device)
|
||||
|
||||
|
||||
def memory_reserved(device=None) -> int:
|
||||
return _dispatch_device_func("memory_reserved", device)
|
||||
|
||||
|
||||
def max_memory_reserved(device=None) -> int:
|
||||
return _dispatch_device_func("max_memory_reserved", device)
|
||||
|
||||
|
||||
def set_per_process_memory_fraction(fraction: float, device=None) -> None:
|
||||
return _dispatch_device_func("set_per_process_memory_fraction", fraction, device)
|
||||
|
||||
|
||||
def reset_peak_memory_stats(device=None) -> None:
|
||||
return _dispatch_device_func("reset_peak_memory_stats", device)
|
||||
|
||||
|
||||
# amp
|
||||
|
||||
|
||||
def autocast() -> Callable:
|
||||
if torch.cuda.is_available():
|
||||
return torch.cuda.amp.autocast()
|
||||
elif IS_NPU_AVAILABLE:
|
||||
return torch.npu.amp.autocast()
|
||||
else:
|
||||
raise RuntimeError("No device available")
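With `colossalai/utils/device.py` deleted, the per-backend dispatch shown above now lives behind the accelerator object. A hedged sketch of how a call site migrates, restricted to accelerator methods that appear elsewhere in this diff:

```python
import torch

from colossalai.accelerator import get_accelerator

acc = get_accelerator()
device = acc.get_current_device()      # was: get_current_device()
x = torch.ones(4, device=device)
acc.synchronize()                      # was: synchronize()
acc.empty_cache()                      # was: empty_cache()
print(acc.device_count(), acc.name)    # backend info, e.g. "cuda" or "npu"
```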
@@ -3,7 +3,7 @@
 import time
 from typing import Tuple

-from .device import synchronize
+from colossalai.accelerator import get_accelerator


 class Timer:

@@ -21,13 +21,13 @@ class Timer:

     @property
     def current_time(self) -> float:
-        synchronize()
+        get_accelerator().synchronize()
         return time.time()

     def start(self):
         """Firstly synchronize cuda, reset the clock and then start the timer."""
         self._elapsed = 0
-        synchronize()
+        get_accelerator().synchronize()
         self._start_time = time.time()
         self._started = True

@@ -44,7 +44,7 @@ class Timer:
         Returns:
             int: Start-stop interval.
         """
-        synchronize()
+        get_accelerator().synchronize()
         end_time = time.time()
         elapsed = end_time - self._start_time
         if keep_in_history:
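Synchronizing before reading the clock is what makes these timings meaningful on an asynchronous device. A small sketch of the same measurement pattern as a free function (name illustrative):

```python
import time

from colossalai.accelerator import get_accelerator


def timed(fn, *args, **kwargs):
    """Wall-clock a function, synchronizing the device before and after
    so asynchronous kernels are not misattributed."""
    get_accelerator().synchronize()
    start = time.time()
    out = fn(*args, **kwargs)
    get_accelerator().synchronize()
    return out, time.time() - start
```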
@ -6,8 +6,7 @@ import torch
|
|||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.device import IS_NPU_AVAILABLE
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
|
||||
class TensorState(Enum):
|
||||
|
@ -107,7 +106,7 @@ class Chunk:
|
|||
self.valid_end = self.shard_size
|
||||
|
||||
self.dtype = dtype
|
||||
device = init_device or get_current_device()
|
||||
device = init_device or get_accelerator().get_current_device()
|
||||
|
||||
# chunk_temp is a global chunk, which only exists during building the chunks.
|
||||
self.chunk_temp = torch.zeros(chunk_size, dtype=dtype, device=device) # keep all zero
|
||||
|
@ -125,7 +124,7 @@ class Chunk:
|
|||
# configure the init device of the shard
|
||||
# no-offload default: fp16, fp32 -> CUDA
|
||||
# offload default: fp16, fp32 -> CPU
|
||||
self.shard_device = torch.device("cpu") if cpu_shard_init else get_current_device()
|
||||
self.shard_device = torch.device("cpu") if cpu_shard_init else get_accelerator().get_current_device()
|
||||
|
||||
self.chunk_mem = self.chunk_size * self.chunk_temp.element_size()
|
||||
self.shard_mem = self.chunk_mem // self.pg_size
|
||||
|
@ -192,10 +191,7 @@ class Chunk:
|
|||
if self.chunk_temp is not None:
|
||||
return self.chunk_temp.device.type
|
||||
else:
|
||||
if self.is_gathered or self.cuda_shard is not None:
|
||||
return "npu" if IS_NPU_AVAILABLE else "cuda"
|
||||
else:
|
||||
return "cpu"
|
||||
return get_accelerator().name
|
||||
|
||||
@property
|
||||
def payload(self) -> torch.Tensor:
|
||||
|
@ -297,7 +293,7 @@ class Chunk:
|
|||
self.valid_end = self.utilized_size - self.shard_begin
|
||||
|
||||
if self.chunk_temp.device.type == "cpu":
|
||||
self.cuda_global_chunk = self.chunk_temp.to(get_current_device())
|
||||
self.cuda_global_chunk = self.chunk_temp.to(get_accelerator().get_current_device())
|
||||
self.__update_tensors_ptr()
|
||||
else:
|
||||
self.cuda_global_chunk = self.chunk_temp
|
||||
|
@ -334,12 +330,12 @@ class Chunk:
|
|||
return
|
||||
|
||||
if device.type == "cuda" or device.type == "npu":
|
||||
assert device == get_current_device(), "can't move chunk to another device"
|
||||
assert device == get_accelerator().get_current_device(), "can't move chunk to another device"
|
||||
|
||||
if self.cuda_shard:
|
||||
return
|
||||
|
||||
self.cuda_shard = self.cpu_shard.to(get_current_device())
|
||||
self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device())
|
||||
|
||||
if not self.pin_memory:
|
||||
self.cpu_shard = None
|
||||
|
@ -394,7 +390,9 @@ class Chunk:
|
|||
if self.extra_dp_group is not None:
|
||||
dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
|
||||
else:
|
||||
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
|
||||
self.cuda_shard = torch.empty(
|
||||
self.shard_size, dtype=self.dtype, device=get_accelerator().get_current_device()
|
||||
)
|
||||
|
||||
input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
|
||||
dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
|
||||
|
@ -533,7 +531,7 @@ class Chunk:
|
|||
# only be called when optimizer state is in CPU memory
|
||||
# the grad and param should be in the same device
|
||||
assert self.cuda_shard is None
|
||||
temp = optim_chunk.cpu_shard.to(get_current_device())
|
||||
temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device())
|
||||
# avoid to transform FP32 in CPU
|
||||
self.cuda_shard = temp.to(self.dtype)
|
||||
|
||||
|
@ -631,7 +629,7 @@ class Chunk:
|
|||
grad_chunk.valid_end = self.valid_end
|
||||
|
||||
if grad_chunk.chunk_temp.device.type == "cpu":
|
||||
grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_current_device())
|
||||
grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp.to(get_accelerator().get_current_device())
|
||||
else:
|
||||
grad_chunk.cuda_global_chunk = grad_chunk.chunk_temp
|
||||
grad_chunk.chunk_temp = None
|
||||
|
|
|
@ -5,7 +5,8 @@ import torch
|
|||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.utils import free_storage, get_current_device
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.utils import free_storage
|
||||
|
||||
from .chunk import Chunk, ChunkFullError, TensorState
|
||||
|
||||
|
@ -20,7 +21,7 @@ class ChunkManager:
|
|||
"""
|
||||
|
||||
def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
|
||||
self.device = init_device or get_current_device()
|
||||
self.device = init_device or get_accelerator().get_current_device()
|
||||
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
|
||||
self.kwargs_config = chunk_configuration
|
||||
for k, v in self.kwargs_config.items():
|
||||
|
@ -107,7 +108,7 @@ class ChunkManager:
|
|||
return
|
||||
self.__sub_memory_usage(chunk.memory_usage)
|
||||
if chunk.device_type == "cpu":
|
||||
chunk.shard_move(get_current_device())
|
||||
chunk.shard_move(get_accelerator().get_current_device())
|
||||
self.__add_accessed_chunk(chunk)
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
|
||||
|
@ -276,7 +277,10 @@ class ChunkManager:
|
|||
accumulated_grad = chunk.grad_chunk.cuda_shard.clone().detach().mul_(chunk.pg_size)
|
||||
else:
|
||||
accumulated_grad = (
|
||||
chunk.grad_chunk.cpu_shard.to(get_current_device()).clone().detach().mul_(chunk.pg_size)
|
||||
chunk.grad_chunk.cpu_shard.to(get_accelerator().get_current_device())
|
||||
.clone()
|
||||
.detach()
|
||||
.mul_(chunk.pg_size)
|
||||
)
|
||||
accumulated_grad_gathered = False
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ import torch.nn as nn
|
|||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import _get_default_group
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
|
||||
from colossalai.interface import ModelWrapper
|
||||
from colossalai.lazy import LazyTensor
|
||||
|
@ -27,7 +28,7 @@ from colossalai.tensor.d_tensor import (
|
|||
is_distributed_tensor,
|
||||
)
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
|
||||
from colossalai.utils import _cast_float, free_storage, is_ddp_ignored
|
||||
|
||||
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
|
||||
from .gemini_hook import GeminiZeROHook
|
||||
|
@ -766,7 +767,7 @@ class GeminiDDP(ModelWrapper):
|
|||
|
||||
# move ignored parameters to CUDA
|
||||
if is_ddp_ignored(p):
|
||||
p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
|
||||
p.data = p.data.to(device=get_accelerator().get_current_device(), dtype=self.mixed_precision)
|
||||
continue
|
||||
|
||||
# create a fp16 parameter
|
||||
|
@ -815,7 +816,7 @@ class GeminiDDP(ModelWrapper):
|
|||
for buffer in self.module.buffers():
|
||||
if isinstance(buffer, LazyTensor):
|
||||
buffer.materialize()
|
||||
buffer.data = buffer.to(get_current_device())
|
||||
buffer.data = buffer.to(get_accelerator().get_current_device())
|
||||
if torch.is_floating_point(buffer):
|
||||
buffer.data = buffer.to(self.mixed_precision)
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ from torch.distributed import ProcessGroup
|
|||
from torch.nn import Parameter
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
|
||||
from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
|
@ -26,7 +27,7 @@ from colossalai.tensor.d_tensor import (
|
|||
is_customized_distributed_tensor,
|
||||
is_distributed_tensor,
|
||||
)
|
||||
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
|
||||
from colossalai.utils import disposable, is_ddp_ignored
|
||||
|
||||
from .chunk import Chunk, ChunkManager
|
||||
from .gemini_ddp import GeminiDDP
|
||||
|
@ -233,7 +234,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
|
||||
grad_chunk.l2_norm = None # clear l2 norm
|
||||
|
||||
comm_buffer = torch.zeros(1, dtype=torch.float, device=get_current_device())
|
||||
comm_buffer = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())
|
||||
for group, part_norm in group_to_norm.items():
|
||||
comm_buffer.fill_(part_norm)
|
||||
dist.all_reduce(comm_buffer, group=group)
|
||||
|
@ -314,10 +315,10 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
continue
|
||||
|
||||
if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
|
||||
self.chunk_manager.move_chunk(chunk32, get_current_device())
|
||||
self.chunk_manager.move_chunk(chunk32, get_accelerator().get_current_device())
|
||||
# stores grad now
|
||||
self.chunk_manager.move_chunk(chunk16, get_current_device())
|
||||
self.module.set_chunk_grad_device(chunk16, get_current_device())
|
||||
self.chunk_manager.move_chunk(chunk16, get_accelerator().get_current_device())
|
||||
self.module.set_chunk_grad_device(chunk16, get_accelerator().get_current_device())
|
||||
fp32_params_used_cuda_margin_mem += chunk32.payload_mem
|
||||
|
||||
for group in self.param_groups:
|
||||
|
@ -328,7 +329,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
state = self.optim.state[fake_param]
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
state[k] = v.to(get_current_device())
|
||||
state[k] = v.to(get_accelerator().get_current_device())
|
||||
|
||||
def _register_states_(self):
|
||||
for group in self.optim.param_groups:
|
||||
|
@ -551,7 +552,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
self,
|
||||
param_id: int,
|
||||
state_names: list,
|
||||
device: torch.device = get_current_device(),
|
||||
device: torch.device = get_accelerator().get_current_device(),
|
||||
dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
|
|
|
@@ -1,6 +1,6 @@
 from typing import Optional

-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator
 from colossalai.zero.gemini.chunk import ChunkManager

 from .memory_stats import MemStats

@@ -33,4 +33,4 @@ class ChunkMemStatsCollector(MemStatsCollector):
     def cuda_margin_mem(self) -> float:
         from colossalai.legacy.utils.memory import colo_device_memory_capacity

-        return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda
+        return colo_device_memory_capacity(get_accelerator().get_current_device()) - self._memstats.max_overall_cuda
@ -5,7 +5,7 @@ from time import sleep, time
|
|||
|
||||
import torch
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
|
||||
class MemoryMonitor:
|
||||
|
@ -77,7 +77,7 @@ class AsyncMemoryMonitor(MemoryMonitor):
|
|||
super().__init__()
|
||||
self.keep_measuring = False
|
||||
|
||||
current_device = get_current_device()
|
||||
current_device = get_accelerator().get_current_device()
|
||||
|
||||
def _set_cuda_device():
|
||||
torch.cuda.set_device(current_device)
|
||||
|
@ -116,7 +116,7 @@ class AsyncMemoryMonitor(MemoryMonitor):
|
|||
while self.keep_measuring:
|
||||
max_usage = max(
|
||||
max_usage,
|
||||
colo_device_memory_used(get_current_device()),
|
||||
colo_device_memory_used(get_accelerator().get_current_device()),
|
||||
)
|
||||
sleep(self.interval)
|
||||
return max_usage
|
||||
|
|
|
@ -6,8 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type
|
|||
|
||||
import torch
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.legacy.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.gemini.chunk import Chunk
|
||||
|
||||
from .chunk import Chunk, ChunkManager
|
||||
|
@ -85,7 +85,7 @@ class StaticPlacementPolicy(PlacementPolicy):
|
|||
# init offload optim settings
|
||||
# keep gathered chunks are in CUDA
|
||||
if chunk.keep_gathered or offloaded_optim_chunk_mem >= offload_optim_chunk_mem:
|
||||
device = get_current_device()
|
||||
device = get_accelerator().get_current_device()
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
# real offloaded mem is chunk.shard_mem, for simplicity we use chunk mem here
|
||||
|
@ -140,7 +140,7 @@ class AutoPlacementPolicy(PlacementPolicy):
|
|||
int: the volume of memory that is evicted
|
||||
"""
|
||||
start = time()
|
||||
cuda_capacity = colo_device_memory_capacity(get_current_device())
|
||||
cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
|
||||
used_cuda_model_data = self.chunk_manager.total_mem["cuda"]
|
||||
if warmup:
|
||||
# We designate a part of CUDA memory for model data in warmup iterations.
|
||||
|
@ -194,7 +194,7 @@ class AutoPlacementPolicy(PlacementPolicy):
|
|||
# init offload optim settings
|
||||
# keep gathered chunks are in CUDA
|
||||
if chunk.keep_gathered:
|
||||
grads_device_map[p] = get_current_device()
|
||||
grads_device_map[p] = get_accelerator().get_current_device()
|
||||
else:
|
||||
grads_device_map[p] = torch.device("cpu")
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.accelerator import get_accelerator
|
||||
|
||||
from .chunk import Chunk
|
||||
|
||||
|
@ -18,11 +18,11 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk, dtype: torch.dtype):
|
|||
if chunk.cuda_shard is not None:
|
||||
shard_temp = chunk.cuda_shard
|
||||
else:
|
||||
shard_temp = chunk.cpu_shard.to(get_current_device())
|
||||
shard_temp = chunk.cpu_shard.to(get_accelerator().get_current_device())
|
||||
|
||||
shard_temp = shard_temp.to(dtype)
|
||||
|
||||
total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_current_device())
|
||||
total_temp = torch.zeros(chunk.chunk_size, dtype=dtype, device=get_accelerator().get_current_device())
|
||||
gather_list = list(torch.chunk(input=total_temp, chunks=chunk.pg_size, dim=0))
|
||||
dist.all_gather(tensor_list=gather_list, tensor=shard_temp, group=chunk.torch_pg)
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup
from torch.optim import Optimizer

import colossalai.utils.device as device_utils
from colossalai.accelerator import get_accelerator
from colossalai.amp.naive_amp.mixed_precision_mixin import (
BF16MixedPrecisionMixin,
FP16MixedPrecisionMixin,
@ -22,9 +22,6 @@ from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.tensor.moe_tensor.api import is_moe_tensor

# from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device

from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, ParameterStore

@ -183,7 +180,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# intialize communication stream for
# communication-compuation overlapping
if self._overlap_communication:
self._comm_stream = device_utils.Stream()
self._comm_stream = get_accelerator().Stream()

# reduction hook is only used if overlapping communication
# or stage 2 is used
@ -217,7 +214,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
return len(self._working_param_groups)

def _sanity_checks(self):
assert torch.cuda.is_available() or IS_NPU_AVAILABLE, "device is required"
assert get_accelerator().name in ["cuda", "npu"], "device is required"
for param_group in self.optim.param_groups:
group_params = param_group["params"]
for param in group_params:
@ -228,7 +225,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def _create_master_param_current_rank(self, param_list):
# split each param evenly by world size
params_current_rank = []
device = "cpu" if self._cpu_offload else get_current_device()
device = "cpu" if self._cpu_offload else get_accelerator().get_current_device()

for param in param_list:
padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
@ -340,11 +337,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if len(moe_grad_list) > 0:
moe_flat_grads.record_stream(stream)
# waiting for ops in the default stream finishing
stream.wait_stream(device_utils.current_stream())
stream.wait_stream(get_accelerator().current_stream())
else:
stream = device_utils.current_stream()
stream = get_accelerator().current_stream()

with device_utils.stream(stream):
with get_accelerator().stream(stream):
group_id = self._bucket_store.current_group_id

if self.moe_extra_dp_pg is None:
@ -486,7 +483,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

# clear reduced grads
if self._overlap_communication:
device_utils.synchronize()
get_accelerator().synchronize()

self.zero_grad()

@ -505,7 +502,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

# clear reduced grads
if self._overlap_communication:
device_utils.synchronize()
get_accelerator().synchronize()

self.zero_grad()

@ -621,7 +618,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
release_param_grad(self._master_param_groups_of_current_rank[group_id])

# update working partition updated by the current rank
device = get_current_device()
device = get_accelerator().get_current_device()
for group_id in range(self.num_param_groups):
master_working_param = self.optim.param_groups[group_id]["params"]
for idx, splited_param in enumerate(master_working_param):
@ -661,7 +658,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float)
total_norm_cuda = torch.tensor(
[float(total_norm)], device=get_accelerator().get_current_device(), dtype=torch.float
)
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
total_norm = total_norm_cuda.item()

@ -673,7 +672,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

# Sum across all model parallel GPUs.
total_norm_exponentiated_cuda = torch.tensor(
[float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float
[float(total_norm_exponentiated)], device=get_accelerator().get_current_device(), dtype=torch.float
)
torch.distributed.all_reduce(
total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
@ -765,7 +764,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
Dict: the pytorch form state_dict
"""
zero_state = dict()
device = get_current_device()
device = get_accelerator().get_current_device()
for param, state in self.optim.state.items():
zero_state[param] = copy.deepcopy(state)
for k, v in state.items():
@ -827,7 +826,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
ret_block = dict()
ret_block_size = 0

device = get_current_device()
device = get_accelerator().get_current_device()
local_states = self.optim.state_dict()["state"]
for param_idx, states in local_states.items():
current_block_size = 0
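The optimizer changes above all follow one substitution: device helpers that used to come from `colossalai.utils` / `colossalai.utils.device` are now routed through the accelerator singleton. A minimal sketch of the new call sites as they appear in this diff (added here for illustration, not part of the commit; it assumes Colossal-AI 0.3.4+ with a CUDA or NPU backend available):

```python
# Illustrative sketch of the accelerator API used throughout this diff.
# Assumes colossalai >= 0.3.4 and a CUDA or NPU device is present.
import torch
from colossalai.accelerator import get_accelerator

accelerator = get_accelerator()

device = accelerator.get_current_device()        # replaces colossalai.utils.get_current_device()
x = torch.ones(4, device=device)

comm_stream = accelerator.Stream()               # replaces colossalai.utils.device.Stream()
comm_stream.wait_stream(accelerator.current_stream())
with accelerator.stream(comm_stream):            # replaces device_utils.stream(...)
    y = x * 2

accelerator.synchronize()                        # replaces device_utils.synchronize()
print(f"max allocated: {accelerator.max_memory_allocated() / 1024**2:.2f} MB")
```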
@ -45,7 +45,6 @@ from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
```
## Define Plugin
Create a `HybridParallelPlugin` object and specify the desired parallelism strategies to be used. In this example, both pipeline parallelism and ZeRO-1 are used simultaneously.
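A minimal sketch (added for illustration, not part of the documentation excerpt above) of the plugin setup it describes; the argument names are assumed from the public `HybridParallelPlugin` interface and the values are placeholders:

```python
# Sketch only: pipeline parallelism (pp_size=2) combined with ZeRO stage 1,
# as the documentation above describes. Sizes and precision are placeholders.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=1,           # no tensor parallelism in this example
    pp_size=2,           # split the model into 2 pipeline stages
    zero_stage=1,        # ZeRO-1 optimizer-state sharding
    num_microbatches=4,  # pipeline schedule granularity
    precision="fp16",
)
booster = Booster(plugin=plugin)
```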
@ -43,7 +43,6 @@ from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
```
### Define plugin
Define a [`HybridParallelPlugin`](../basics/booster_plugins.md) object and specify the parallelism strategies to use. In this example, pipeline parallelism and ZeRO-1 are used together.
@ -16,10 +16,10 @@ from utils.global_vars import get_tensorboard_writer, get_timers, set_global_var
from utils.logger import Logger

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.context import ParallelMode
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

@ -53,7 +53,7 @@ def main():
set_global_variables(launch_time, args.tensorboard_path)

world_size = torch.distributed.get_world_size()
get_current_device()
get_accelerator().get_current_device()

# build model, optimizer and criterion
if args.distplan.startswith("CAI"):
@ -67,7 +67,10 @@ def main():

# build GPT model
with ColoInitContext(
device=get_current_device(), dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg
device=get_accelerator().get_current_device(),
dtype=torch.half,
default_dist_spec=default_dist_spec,
default_pg=shard_pg,
):
config, model, numel = get_model(args, logger)

@ -78,7 +81,7 @@ def main():
elif args.distplan == "CAI_Gemini":
gemini_config = dict(
strict_ddp_mode=args.tp_degree == 1,
device=get_current_device(),
device=get_accelerator().get_current_device(),
placement_policy=args.placement,
pin_memory=True,
hidden_dim=model.config.hidden_size,

@ -20,11 +20,11 @@ from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

disable_existing_loggers()
logger = get_dist_logger()
@ -386,7 +386,7 @@ def main(args):
cur_class_images = len(list(class_images_dir.iterdir()))

if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32
torch_dtype = torch.float16 if get_accelerator().get_current_device() == "cuda" else torch.float32
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
@ -401,7 +401,7 @@ def main(args):
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

pipeline.to(get_current_device())
pipeline.to(get_accelerator().get_current_device())

for example in tqdm(
sample_dataloader,
@ -578,8 +578,8 @@ def main(args):
# Move text_encode and vae to gpu.
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
vae.to(get_current_device(), dtype=weight_dtype)
text_encoder.to(get_current_device(), dtype=weight_dtype)
vae.to(get_accelerator().get_current_device(), dtype=weight_dtype)
text_encoder.to(get_accelerator().get_current_device(), dtype=weight_dtype)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader))
@ -613,7 +613,7 @@ def main(args):
torch.cuda.reset_peak_memory_stats()
# Move batch to gpu
for key, value in batch.items():
batch[key] = value.to(get_current_device(), non_blocking=True)
batch[key] = value.to(get_accelerator().get_current_device(), non_blocking=True)

# Convert images to latent space
optimizer.zero_grad()

@ -21,13 +21,13 @@ from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

disable_existing_loggers()
logger = get_dist_logger()
@ -385,7 +385,7 @@ def main(args):
cur_class_images = len(list(class_images_dir.iterdir()))

if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32
torch_dtype = torch.float16 if get_accelerator().get_current_device() == "cuda" else torch.float32
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
@ -400,7 +400,7 @@ def main(args):
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

pipeline.to(get_current_device())
pipeline.to(get_accelerator().get_current_device())

for example in tqdm(
sample_dataloader,
@ -598,8 +598,8 @@ def main(args):
# Move text_encode and vae to gpu.
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
vae.to(get_current_device(), dtype=weight_dtype)
text_encoder.to(get_current_device(), dtype=weight_dtype)
vae.to(get_accelerator().get_current_device(), dtype=weight_dtype)
text_encoder.to(get_accelerator().get_current_device(), dtype=weight_dtype)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader))
@ -633,7 +633,7 @@ def main(args):
torch.cuda.reset_peak_memory_stats()
# Move batch to gpu
for key, value in batch.items():
batch[key] = value.to(get_current_device(), non_blocking=True)
batch[key] = value.to(get_accelerator().get_current_device(), non_blocking=True)

# Convert images to latent space
optimizer.zero_grad()

@ -13,12 +13,12 @@ from torch.utils.data import DataLoader
from tqdm import tqdm

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

# ==============================
# Prepare Hyperparameters
@ -53,8 +53,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl
@torch.no_grad()
def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float:
model.eval()
correct = torch.zeros(1, dtype=torch.int64, device=get_current_device())
total = torch.zeros(1, dtype=torch.int64, device=get_current_device())
correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
for images, labels in test_dataloader:
images = images.cuda()
labels = labels.cuda()

@ -33,9 +33,10 @@ def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224

def colo_memory_cap(size_in_GB):
from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
from colossalai.accelerator import get_accelerator
from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction

cuda_capacity = colo_device_memory_capacity(get_current_device())
cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
if size_in_GB * (1024**3) < cuda_capacity:
colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
print(f"Limiting GPU memory usage to {size_in_GB} GB")

@ -6,10 +6,9 @@ import torch.distributed as dist
import transformers

import colossalai
import colossalai.utils.device as device_utils
from colossalai.accelerator import get_accelerator
from colossalai.inference import InferenceEngine
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from colossalai.utils.device import get_current_device

GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024
@ -52,7 +51,7 @@ CONFIG_MAP = {

def data_gen(batch_size: int = 4, seq_len: int = 512):
input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_current_device())
input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_accelerator().get_current_device())
attention_mask = torch.ones_like(input_ids)
data = dict(input_ids=input_ids, attention_mask=attention_mask)
return data
@ -97,9 +96,9 @@ def print_details_info(outputs, model_config, args, whole_end2end):
msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n"

if torch.cuda.is_available():
msg += f"-------Memory Summary Device:{device_utils.current_device()}-------\n"
msg += f"Max memory allocated: {device_utils.max_memory_allocated() / GIGABYTE:.2f} GB\n"
msg += f"Max memory reserved: {device_utils.max_memory_reserved() / GIGABYTE:.2f} GB\n"
msg += f"-------Memory Summary Device:{get_accelerator().current_device()}-------\n"
msg += f"Max memory allocated: {get_accelerator().max_memory_allocated() / GIGABYTE:.2f} GB\n"
msg += f"Max memory reserved: {get_accelerator().max_memory_reserved() / GIGABYTE:.2f} GB\n"

print(msg)

@ -5,9 +5,9 @@ import torch.distributed as dist
from transformers import LlamaForCausalLM, LlamaTokenizer

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.inference import InferenceEngine
from colossalai.testing import spawn
from colossalai.utils.device import get_current_device

INPUT_TEXTS = [
"What is the longest river in the world?",
@ -57,7 +57,7 @@ def run_inference(args):
)

inputs = tokenizer(INPUT_TEXTS, return_tensors="pt", padding="longest", max_length=max_input_len, truncation=True)
inputs = {k: v.to(get_current_device()) for k, v in inputs.items()}
inputs = {k: v.to(get_accelerator().get_current_device()) for k, v in inputs.items()}
outputs = engine.generate(inputs)

if rank == 0:

@ -18,11 +18,11 @@ from transformers import (
)

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

# ==============================
# Prepare Hyperparameters
@ -59,7 +59,7 @@ def evaluate_model(
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()

accum_loss = torch.zeros(1, device=get_current_device())
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
for batch in dataloader:
batch = move_to_cuda(batch)
labels = batch["labels"]
@ -88,8 +88,10 @@ def evaluate_model(
object_list = [None, None]
dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group)

metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels)
accum_loss.add_(object_list[1].to(get_current_device()))
metric.add_batch(
predictions=object_list[0].to(get_accelerator().get_current_device()), references=labels
)
accum_loss.add_(object_list[1].to(get_accelerator().get_current_device()))

else:
batch = move_to_cuda(batch)

@ -7,13 +7,13 @@ from model_zoo import GPTLMLoss, get_gpt2_components
from torch.utils._pytree import tree_map

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
from colossalai.auto_parallel.offload.solver import NOT_NVML
from colossalai.fx.profiler import parameter_size
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import spawn
from colossalai.utils import get_current_device

def parse_args():
@ -41,7 +41,7 @@ def train_gpt(args):
64,
8,
),
device=get_current_device(),
device=get_accelerator().get_current_device(),
)
criterion = GPTLMLoss()

@ -12,12 +12,12 @@ from commons.utils import get_data, get_profile_context, get_tflops, get_time_st
from packaging import version

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

CAI_VERSION = colossalai.__version__

@ -141,7 +141,11 @@ def main():
criterion = GPTLMLoss()
torch.manual_seed(123)
if args.distplan.startswith("CAI"):
ctx = LazyInitContext(default_device=get_current_device()) if args.distplan == "CAI_Gemini" else nullcontext()
ctx = (
LazyInitContext(default_device=get_accelerator().get_current_device())
if args.distplan == "CAI_Gemini"
else nullcontext()
)
# build GPT model
with ctx:
model = model_builder(args.model_type)(checkpoint=True)

@ -13,11 +13,11 @@ from tqdm import tqdm
from transformers import AutoConfig, GPT2ForSequenceClassification, get_linear_schedule_with_warmup

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

# ==============================
# Prepare Hyperparameters
@ -54,7 +54,7 @@ def evaluate_model(
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()

accum_loss = torch.zeros(1, device=get_current_device())
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
for batch in dataloader:
batch = move_to_cuda(batch)
labels = batch["labels"]
@ -83,8 +83,10 @@ def evaluate_model(
object_list = [None, None]
dist.broadcast_object_list(object_list, src=current_pp_group_ranks[-1], group=pp_group)

metric.add_batch(predictions=object_list[0].to(get_current_device()), references=labels)
accum_loss.add_(object_list[1].to(get_current_device()))
metric.add_batch(
predictions=object_list[0].to(get_accelerator().get_current_device()), references=labels
)
accum_loss.add_(object_list[1].to(get_accelerator().get_current_device()))

else:
batch = move_to_cuda(batch)

@ -5,6 +5,7 @@ from torch import nn as nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter

from colossalai.accelerator import get_accelerator
from colossalai.legacy.context import ParallelMode, seed
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.base_layer import ParallelLayer
@ -12,7 +13,6 @@ from colossalai.legacy.nn.layer.parallel_1d._utils import gather_forward_split_b
from colossalai.legacy.nn.layer.parallel_1d.layers import Linear1D_Row
from colossalai.legacy.nn.layer.utils import divide
from colossalai.legacy.registry import LAYERS, LOSSES
from colossalai.utils import get_current_device

class VocabParallelEmbedding(torch.nn.Module):
@ -96,7 +96,9 @@ class VocabParallelEmbedding(torch.nn.Module):
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if position_ids is None:
position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device())
position_ids = torch.arange(
0, input_shape[-1] + 0, dtype=torch.long, device=get_accelerator().get_current_device()
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
position_embeddings = self.position_embeddings(position_ids)

@ -194,7 +196,7 @@ class VocabParallelEmbedding1D(torch.nn.Module):
self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index

# Allocate weights and initialize.
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs))
init.uniform_(self.weight, -1, 1)

@ -439,7 +441,9 @@ class HiddenParallelEmbedding(torch.nn.Module):
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if position_ids is None:
position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device())
position_ids = torch.arange(
0, input_shape[-1] + 0, dtype=torch.long, device=get_accelerator().get_current_device()
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
position_embeddings = self.position_embeddings(position_ids)

@ -532,7 +536,7 @@ class HiddenParallelEmbedding1D(torch.nn.Module):
self._weight = None

# Allocate weights and initialize.
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs))
init.uniform_(self.weight, -1, 1)

@ -13,13 +13,12 @@ from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

import colossalai
import colossalai.utils.device as device_utils
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

# ==============================
# Constants
@ -166,7 +165,7 @@ def main():
# Initialize Model and Optimizer
# ==============================
init_ctx = (
LazyInitContext(default_device=get_current_device())
LazyInitContext(default_device=get_accelerator().get_current_device())
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
else nullcontext()
)
@ -197,7 +196,9 @@ def main():
torch.set_default_dtype(torch.bfloat16)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {device_utils.max_memory_allocated()/1024**2:.2f} MB")
coordinator.print_on_master(
f"Booster init max CUDA memory: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB"
)
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
)
@ -223,7 +224,7 @@ def main():
performance_evaluator.on_step_end(**batch)

performance_evaluator.on_fit_end()
coordinator.print_on_master(f"Max CUDA memory usage: {device_utils.max_memory_allocated()/1024**2:.2f} MB")
coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")

if __name__ == "__main__":

@ -8,7 +8,7 @@ from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from torch.utils.data import DataLoader, Dataset, DistributedSampler

from colossalai.utils import get_current_device
from colossalai.accelerator import get_accelerator

class StatefulDistributedSampler(DistributedSampler):
@ -108,7 +108,9 @@ class RandomDataset(Dataset):
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
self.num_samples = num_samples
self.max_length = max_length
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
self.input_ids = torch.randint(
0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
)
self.attention_mask = torch.ones_like(self.input_ids)

def __len__(self):

@ -21,13 +21,13 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.llama.tokenization_llama import LlamaTokenizer

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

def get_model_numel(model: nn.Module) -> int:
@ -191,7 +191,9 @@ def main():
config = LlamaConfig.from_pretrained(args.model_path)
# use lazy init when using GeminiPlugin
init_ctx = (
LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
LazyInitContext(default_device=get_accelerator().get_current_device())
if isinstance(plugin, GeminiPlugin)
else nullcontext()
)

with init_ctx:
@ -5,9 +5,8 @@ import torch
import torch.distributed as dist
from torch import Tensor

import colossalai.utils.device as device_utils
from colossalai.accelerator import get_accelerator
from colossalai.cluster import DistCoordinator
from colossalai.utils.device import get_current_device

def divide(x: float, y: float) -> float:
@ -22,7 +21,7 @@ def divide(x: float, y: float) -> float:
def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1:
return x
tensor = torch.tensor([x], device=get_current_device())
tensor = torch.tensor([x], device=get_accelerator().get_current_device())
dist.all_reduce(tensor)
tensor = tensor / world_size
return tensor.item()
@ -86,13 +85,13 @@ class PerformanceEvaluator:
self.disable = self.ignore_steps > 0 and step < self.ignore_steps
if self.disable:
return
device_utils.synchronize()
get_accelerator().synchronize()
self.timer.start()

def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
if self.disable:
return
device_utils.synchronize()
get_accelerator().synchronize()
self.timer.end()

batch_size, seq_len = input_ids.shape
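For context, a hedged usage sketch (not part of the diff) of the `all_reduce_mean` helper shown above; it assumes `torch.distributed` has already been initialized by the training launcher, and `local_loss` is a placeholder value:

```python
# Hypothetical usage of the all_reduce_mean helper above:
# every rank contributes its local step loss and receives the global mean back.
import torch.distributed as dist

local_loss = 0.42  # placeholder per-rank scalar
world_size = dist.get_world_size() if dist.is_initialized() else 1
mean_loss = all_reduce_mean(local_loss, world_size)
```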
@ -20,13 +20,13 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.llama.tokenization_llama import LlamaTokenizer

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

MODEL_CONFIGS = {
"7b": LlamaConfig(max_position_embeddings=4096),
@ -227,7 +227,9 @@ def main():
config = MODEL_CONFIGS[args.config]
# use lazy init when using GeminiPlugin
init_ctx = (
LazyInitContext(default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
LazyInitContext(default_device=get_accelerator().get_current_device())
if isinstance(plugin, GeminiPlugin)
else nullcontext()
)

with init_ctx:

@ -14,6 +14,7 @@ from transformers.models.llama import LlamaConfig
from utils import PerformanceEvaluator, get_model_numel

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
@ -21,7 +22,6 @@ from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

def move_to_cuda(batch, device):
@ -64,13 +64,15 @@ class RandomDataset(Dataset):
)
self.input_ids.append(encode["input_ids"])
self.attention_mask.append(encode["attention_mask"])
self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device())
self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device())
self.input_ids = torch.cat(self.input_ids, dim=0).to(get_accelerator().get_current_device())
self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_accelerator().get_current_device())
repeat_times = num_samples // self.input_ids.shape[0] + 1
self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples]
self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples]
else:
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
self.input_ids = torch.randint(
0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
)
self.attention_mask = torch.ones_like(self.input_ids)

def __len__(self):

@ -15,6 +15,7 @@ from transformers import T5Tokenizer
from transformers.models.llama import LlamaConfig

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
from colossalai.cluster import DistCoordinator
@ -22,7 +23,6 @@ from colossalai.moe.layers import apply_load_balance
from colossalai.moe.manager import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

def move_to_cuda(batch, device):
@ -61,7 +61,9 @@ class RandomDataset(Dataset):
def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
self.num_samples = num_samples
self.max_length = max_length
self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
self.input_ids = torch.randint(
0, vocab_size, (num_samples, max_length), device=get_accelerator().get_current_device()
)
self.attention_mask = torch.ones_like(self.input_ids)

def __len__(self):

@ -14,12 +14,12 @@ from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from torch.utils.data import DataLoader, Dataset

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import HybridAdam
from colossalai.utils import get_current_device

# constants

@ -159,7 +159,11 @@ if args.distplan == "colossalai":
logger.info(f"plugin: {plugin}")
booster = Booster(plugin=plugin, **booster_kwargs)

ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == "gemini" else nullcontext()
ctx = (
LazyInitContext(default_device=get_accelerator().get_current_device())
if args.plugin == "gemini"
else nullcontext()
)

with ctx:
model = PaLM(num_tokens=50304, dim=4096, depth=64)

@ -13,12 +13,12 @@ from torch.utils.data import DataLoader
from tqdm import tqdm

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

# ==============================
# Prepare Hyperparameters
@ -53,8 +53,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl
@torch.no_grad()
def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float:
model.eval()
correct = torch.zeros(1, dtype=torch.int64, device=get_current_device())
total = torch.zeros(1, dtype=torch.int64, device=get_current_device())
correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
for images, labels in test_dataloader:
images = images.cuda()
labels = labels.cuda()

@ -13,13 +13,13 @@ from torch.utils.data import DataLoader
from tqdm import tqdm

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
from colossalai.cluster import DistCoordinator
from colossalai.nn.lr_scheduler import LinearWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

# ==============================
# Prepare Hyperparameters
@ -73,8 +73,8 @@ def build_dataloader(batch_size: int, coordinator: DistCoordinator, plugin: DPPl
@torch.no_grad()
def evaluate(model: nn.Module, test_dataloader: DataLoader, coordinator: DistCoordinator) -> float:
model.eval()
correct = torch.zeros(1, dtype=torch.int64, device=get_current_device())
total = torch.zeros(1, dtype=torch.int64, device=get_current_device())
correct = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
total = torch.zeros(1, dtype=torch.int64, device=get_accelerator().get_current_device())
for images, labels in test_dataloader:
images = images.cuda()
labels = labels.cuda()

@ -12,11 +12,11 @@ from tqdm import tqdm
from transformers import AutoConfig, BertForSequenceClassification, get_linear_schedule_with_warmup

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

# ==============================
# Prepare Hyperparameters
@ -45,7 +45,7 @@ def evaluate(
model.eval()

def evaluate_subset(dataloader: DataLoader):
accum_loss = torch.zeros(1, device=get_current_device())
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
for batch in dataloader:
batch = move_to_cuda(batch)
outputs = model(**batch)

@ -51,13 +51,13 @@ from transformers import (
from transformers.utils.versions import require_version

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.tensor import ProcessGroup
from colossalai.legacy.utils import get_dataloader
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.zero import GeminiOptimizer

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
@ -249,9 +249,9 @@ def parse_args():

def colo_memory_cap(size_in_GB):
from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction

cuda_capacity = colo_device_memory_capacity(get_current_device())
cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
if size_in_GB * (1024**3) < cuda_capacity:
colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
print("Using {} GB of GPU memory".format(size_in_GB))
@ -265,7 +265,9 @@ class DummyDataloader:
self.vocab_size = vocab_size

def generate(self):
input_ids = torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len), device=get_current_device())
input_ids = torch.randint(
0, self.vocab_size, (self.batch_size, self.seq_len), device=get_accelerator().get_current_device()
)
attention_mask = torch.ones_like(input_ids)
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids}

@ -390,7 +392,7 @@ def main():
if args.init_in_cpu:
init_dev = torch.device("cpu")
else:
init_dev = get_current_device()
init_dev = get_accelerator().get_current_device()

cai_version = colossalai.__version__
logger.info(f"using Colossal-AI version {cai_version}")
@ -439,7 +441,9 @@ def main():
except ImportError:
# this works for unreleased main branch, and this may be released on 0.2.9
from colossalai.zero import GeminiDDP
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
model = GeminiDDP(
model, device=get_accelerator().get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True
)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager
Some files were not shown because too many files have changed in this diff.