diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py index 79661a444..439d13dcf 100644 --- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py @@ -8,6 +8,7 @@ import torch from torch import Tensor from colossalai.logging import get_dist_logger +from colossalai.utils.device import get_current_device __all__ = ["BaseGradScaler"] @@ -22,7 +23,7 @@ class BaseGradScaler(ABC): def __init__(self, initial_scale: float, verbose: bool): assert initial_scale > 0 - self._scale = torch.cuda.FloatTensor([initial_scale]) + self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float) self._verbose = verbose if self._verbose: diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py index 65133a4b3..86ba919ee 100644 --- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py +++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py @@ -5,6 +5,8 @@ from typing import Optional import torch +from colossalai.utils.device import get_current_device + from .base_grad_scaler import BaseGradScaler __all__ = ["DynamicGradScaler"] @@ -37,12 +39,12 @@ class DynamicGradScaler(BaseGradScaler): ): super().__init__(initial_scale, verbose) if min_scale: - self._min_scale = torch.cuda.FloatTensor([min_scale]) + self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float) else: self._min_scale = None if max_scale: - self._max_scale = torch.cuda.FloatTensor([max_scale]) + self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float) else: self._max_scale = None @@ -115,7 +117,7 @@ class DynamicGradScaler(BaseGradScaler): return state_dict def load_state_dict(self, state_dict): - self._scale = state_dict["scale"].cuda(torch.cuda.current_device()) + self._scale = state_dict["scale"].to(get_current_device()) self._growth_factor = state_dict["growth_factor"] self._backoff_factor = state_dict["backoff_factor"] self._hysteresis = state_dict["hysteresis"] diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py index a6b4904f2..a6628e29c 100644 --- a/colossalai/auto_parallel/offload/solver.py +++ b/colossalai/auto_parallel/offload/solver.py @@ -11,7 +11,7 @@ except: import torch from torch.fx.node import Node -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import get_current_device from .region import Region from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 0908fa40d..963e5a71c 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -25,6 +25,7 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.utils import get_current_device +from colossalai.utils.device import IS_NPU_AVAILABLE from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats @@ -37,6 +38,7 @@ PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16} ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2 + def get_param_info(optim: Optimizer): # Get a backup of necessary information of parameters for future use, which includes: # 1. A mapping from integer param_id to param32 shape. @@ -53,6 +55,8 @@ def get_param_info(optim: Optimizer): start_index += len(group["params"]) return param_info + + class GeminiCheckpointIO(GeneralCheckpointIO): def __init__(self) -> None: super().__init__() @@ -359,6 +363,8 @@ class GeminiPlugin(DPPluginBase): ) -> None: super().__init__() assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported" + if IS_NPU_AVAILABLE: + assert placement_policy == "static", "NPU only supports static placement policy" self.gemini_config = dict( chunk_config_dict=chunk_config_dict, chunk_init_device=(chunk_init_device or get_current_device()), @@ -437,7 +443,7 @@ class GeminiPlugin(DPPluginBase): return True def supported_devices(self) -> List[str]: - return ["cuda"] + return ["cuda", "npu"] def configure( self, @@ -485,4 +491,4 @@ class GeminiPlugin(DPPluginBase): return GeminiCheckpointIO() def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 09343138f..89102820c 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -306,7 +306,7 @@ class LowLevelZeroPlugin(DPPluginBase): return True def supported_devices(self) -> List[str]: - return ["cuda"] + return ["cuda", "npu"] def configure( self, diff --git a/colossalai/initialize.py b/colossalai/initialize.py index aac57d34a..25076b742 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -11,7 +11,7 @@ import torch.distributed as dist from colossalai.context import Config from colossalai.logging import get_dist_logger -from colossalai.utils import set_device, set_seed +from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed def launch( @@ -47,12 +47,15 @@ def launch( if rank == 0: warnings.warn("`config` is deprecated and will be removed soon.") + if IS_NPU_AVAILABLE and backend == "nccl": + backend = "hccl" + # init default process group init_method = f"tcp://[{host}]:{port}" dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) # set cuda device - if torch.cuda.is_available(): + if torch.cuda.is_available() or IS_NPU_AVAILABLE: # if local rank is not given, calculate automatically set_device(local_rank) diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/colossalai/kernel/cuda_native/csrc/cpu_adam.h index bf9b85997..db1f26d5f 100644 --- a/colossalai/kernel/cuda_native/csrc/cpu_adam.h +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.h @@ -142,6 +142,7 @@ class Adam_Optimizer { } } +#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__) inline void simd_load(bool is_half, float *ptr, __half *h_ptr, AVX_Data &data) { if (is_half) { @@ -159,6 +160,7 @@ class Adam_Optimizer { SIMD_STORE(ptr, data.data); } } +#endif void step(size_t step, float lr, float beta1, float beta2, float epsilon, float weight_decay, bool bias_correction, torch::Tensor ¶ms, diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp b/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp new file mode 100644 index 000000000..a715a2711 --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp @@ -0,0 +1,304 @@ +#include "cpu_adam_arm.h" + +void AdamOptimizer::Step_1(void *_params, void *grads, void *_exp_avg, + void *_exp_avg_sq, size_t _param_size, + at::ScalarType param_dtype, + at::ScalarType grad_dtype, + at::ScalarType exp_avg_dtype, + at::ScalarType exp_avg_sq_dtype, float loss_scale) { + size_t rounded_size = 0; +#if defined(__aarch64__) + rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH); +#endif + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + float step_size = -1 * _alpha / _bias_correction1; + float w_decay = -1 * _alpha * _weight_decay; + +#if defined(__aarch64__) + float32x4_t betta1_4 = simd_set(_betta1); + float32x4_t betta2_4 = simd_set(_betta2); + float32x4_t betta1_minus1_4 = simd_set(betta1_minus1); + float32x4_t betta2_minus1_4 = simd_set(betta2_minus1); + float32x4_t bias2_sqrt = simd_set(_bias_correction2); + float32x4_t eps_4 = simd_set(_eps); + float32x4_t step_size_4 = simd_set(step_size); + float32x4_t weight_decay_4; + if (_weight_decay > 0) { + weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay); + } + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH) { + float32x4_t grad_4 = simd_load_offset(grads, grad_dtype, i); + if (loss_scale > 0) { + float32x4_t loss_scale_vec = simd_set(loss_scale); + grad_4 = vdivq_f32(grad_4, loss_scale_vec); + } + float32x4_t momentum_4 = simd_load_offset(_exp_avg, exp_avg_dtype, i); + float32x4_t variance_4 = + simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i); + float32x4_t param_4 = simd_load_offset(_params, param_dtype, i); + if (_weight_decay > 0 && !_adamw_mode) { + grad_4 = vfmaq_f32(grad_4, param_4, weight_decay_4); + } + momentum_4 = vmulq_f32(momentum_4, betta1_4); + momentum_4 = vfmaq_f32(momentum_4, grad_4, betta1_minus1_4); + variance_4 = vmulq_f32(variance_4, betta2_4); + grad_4 = vmulq_f32(grad_4, grad_4); + variance_4 = vfmaq_f32(variance_4, grad_4, betta2_minus1_4); + grad_4 = vsqrtq_f32(variance_4); + grad_4 = vfmaq_f32(eps_4, grad_4, bias2_sqrt); + grad_4 = vdivq_f32(momentum_4, grad_4); + if (_weight_decay > 0 && _adamw_mode) { + param_4 = vfmaq_f32(param_4, param_4, weight_decay_4); + } + param_4 = vfmaq_f32(param_4, grad_4, step_size_4); + simd_store_offset(_params, param_dtype, param_4, i); + simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4, i); + simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4, i); + } + } +#endif + if (_param_size > rounded_size) { + for (size_t t = rounded_size; t < _param_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > _param_size) copy_size = _param_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t k = t; k < offset; k++) { + float grad = scalar_load_offset(grads, grad_dtype, k); + if (loss_scale > 0) { + grad /= loss_scale; + } + float param = scalar_load_offset(_params, param_dtype, k); + float momentum = scalar_load_offset(_exp_avg, exp_avg_dtype, k); + float variance = scalar_load_offset(_exp_avg_sq, exp_avg_sq_dtype, k); + if (_weight_decay > 0 && !_adamw_mode) { + grad = param * _weight_decay + grad; + } + momentum = momentum * _betta1; + momentum = grad * betta1_minus1 + momentum; + + variance = variance * _betta2; + grad = grad * grad; + variance = grad * betta2_minus1 + variance; + + grad = sqrt(variance); + grad = grad * _bias_correction2 + _eps; + grad = momentum / grad; + if (_weight_decay > 0 && _adamw_mode) { + param += w_decay * param; + } + param = grad * step_size + param; + + scalar_store_offset(_params, param_dtype, param, k); + scalar_store_offset(_exp_avg, exp_avg_dtype, momentum, k); + scalar_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance, k); + } + } + } +} + +void AdamOptimizer::Step_4(void *_params, void *grads, void *_exp_avg, + void *_exp_avg_sq, size_t _param_size, + at::ScalarType param_dtype, + at::ScalarType grad_dtype, + at::ScalarType exp_avg_dtype, + at::ScalarType exp_avg_sq_dtype, float loss_scale) { + size_t rounded_size = 0; +#if defined(__aarch64__) + rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4); +#endif + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + float step_size = -1 * _alpha / _bias_correction1; + float w_decay = -1 * _alpha * _weight_decay; + +#if defined(__aarch64__) + float32x4_t betta1_4 = simd_set(_betta1); + float32x4_t betta2_4 = simd_set(_betta2); + float32x4_t betta1_minus1_4 = simd_set(betta1_minus1); + float32x4_t betta2_minus1_4 = simd_set(betta2_minus1); + float32x4_t bias2_sqrt = simd_set(_bias_correction2); + float32x4_t eps_4 = simd_set(_eps); + float32x4_t step_size_4 = simd_set(step_size); + float32x4_t weight_decay_4; + if (_weight_decay > 0) { + weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay); + } + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) { + float32x4_t grad_4[4]; + float32x4_t momentum_4[4]; + float32x4_t variance_4[4]; + float32x4_t param_4[4]; +#pragma unroll 4 + for (int j = 0; j < 4; j++) { + grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j); + if (loss_scale > 0) { + float32x4_t loss_scale_vec = simd_set(loss_scale); + grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec); + } + momentum_4[j] = + simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j); + variance_4[j] = + simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j); + param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j); + if (_weight_decay > 0 && !_adamw_mode) { + grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4); + } + momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4); + momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4); + variance_4[j] = vmulq_f32(variance_4[j], betta2_4); + grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]); + variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4); + grad_4[j] = vsqrtq_f32(variance_4[j]); + grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt); + grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]); + if (_weight_decay > 0 && _adamw_mode) { + param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4); + } + param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4); + simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j); + simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j], + i + SIMD_WIDTH * j); + simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j], + i + SIMD_WIDTH * j); + } + } + } +#endif + if (_param_size > rounded_size) { + Step_1(scalar_seek_offset(_params, param_dtype, rounded_size), + scalar_seek_offset(grads, grad_dtype, rounded_size), + scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size), + scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size), + (_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype, + exp_avg_sq_dtype, loss_scale); + } +} + +void AdamOptimizer::Step_8(void *_params, void *grads, void *_exp_avg, + void *_exp_avg_sq, size_t _param_size, + at::ScalarType param_dtype, + at::ScalarType grad_dtype, + at::ScalarType exp_avg_dtype, + at::ScalarType exp_avg_sq_dtype, float loss_scale) { + size_t rounded_size = 0; +#if defined(__aarch64__) + rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8); +#endif + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + float step_size = -1 * _alpha / _bias_correction1; + float w_decay = -1 * _alpha * _weight_decay; +#if defined(__aarch64__) + float32x4_t betta1_4 = simd_set(_betta1); + float32x4_t betta2_4 = simd_set(_betta2); + float32x4_t betta1_minus1_4 = simd_set(betta1_minus1); + float32x4_t betta2_minus1_4 = simd_set(betta2_minus1); + float32x4_t bias2_sqrt = simd_set(_bias_correction2); + float32x4_t eps_4 = simd_set(_eps); + float32x4_t step_size_4 = simd_set(step_size); + float32x4_t weight_decay_4; + if (_weight_decay > 0) { + weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay); + } + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; + +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) { + float32x4_t grad_4[8]; + float32x4_t momentum_4[8]; + float32x4_t variance_4[8]; + float32x4_t param_4[8]; +#pragma unroll 4 + for (int j = 0; j < 8; j++) { + grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j); + if (loss_scale > 0) { + float32x4_t loss_scale_vec = simd_set(loss_scale); + grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec); + } + momentum_4[j] = + simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j); + variance_4[j] = + simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j); + param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j); + if (_weight_decay > 0 && !_adamw_mode) { + grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4); + } + momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4); + momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4); + variance_4[j] = vmulq_f32(variance_4[j], betta2_4); + grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]); + variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4); + grad_4[j] = vsqrtq_f32(variance_4[j]); + grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt); + grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]); + if (_weight_decay > 0 && _adamw_mode) { + param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4); + } + param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4); + simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j); + simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j], + i + SIMD_WIDTH * j); + simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j], + i + SIMD_WIDTH * j); + } + } + } +#endif + if (_param_size > rounded_size) { + Step_4(scalar_seek_offset(_params, param_dtype, rounded_size), + scalar_seek_offset(grads, grad_dtype, rounded_size), + scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size), + scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size), + (_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype, + exp_avg_sq_dtype, loss_scale); + } +} + +void AdamOptimizer::step(size_t step, float lr, float beta1, float beta2, + float epsilon, float weight_decay, + bool bias_correction, torch::Tensor ¶ms, + torch::Tensor &grads, torch::Tensor &exp_avg, + torch::Tensor &exp_avg_sq, float loss_scale) { + auto params_c = params.contiguous(); + auto grads_c = grads.contiguous(); + auto exp_avg_c = exp_avg.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + + this->IncrementStep(step, beta1, beta2); + this->update_state(lr, epsilon, weight_decay, bias_correction); + this->Step_8(params_c.data_ptr(), grads_c.data_ptr(), exp_avg_c.data_ptr(), + exp_avg_sq_c.data_ptr(), params_c.numel(), + params_c.scalar_type(), grads_c.scalar_type(), + exp_avg_c.scalar_type(), exp_avg_sq_c.scalar_type(), loss_scale); +} + +namespace py = pybind11; + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + py::class_(m, "CPUAdamOptimizer") + .def(py::init()) + .def("step", &AdamOptimizer::step); +} diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h b/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h new file mode 100644 index 000000000..c731850ed --- /dev/null +++ b/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h @@ -0,0 +1,201 @@ +#pragma once +#include +#include + +#include + +#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) +#define TILE (128 * 1024 * 1024) + +#if defined(__aarch64__) +#include +#define SIMD_WIDTH 4 + +inline float32x4_t simd_load_offset(const void *ptr, at::ScalarType dtype, + size_t offset) { + switch (dtype) { + case at::ScalarType::Float: { + auto ptr_f = reinterpret_cast(ptr); + return vld1q_f32(ptr_f + offset); + } + case at::ScalarType::Half: { + auto ptr_h = reinterpret_cast(ptr); + return vcvt_f32_f16(vld1_f16(ptr_h + offset)); + } + // case at::ScalarType::BFloat16: { + // auto ptr_b = reinterpret_cast(ptr); + // return vcvt_f32_bf16(vld1_bf16(ptr_b + offset)); + // } + default: + AT_ERROR("Unsupported dtype"); + break; + } +} +inline float32x4_t simd_load(void const *ptr, at::ScalarType dtype) { + return simd_load_offset(ptr, dtype, 0); +} + +inline void simd_store_offset(void *ptr, at::ScalarType dtype, float32x4_t data, + size_t offset) { + switch (dtype) { + case at::ScalarType::Float: { + auto ptr_f = reinterpret_cast(ptr); + vst1q_f32(ptr_f + offset, data); + break; + } + case at::ScalarType::Half: { + auto ptr_h = reinterpret_cast(ptr); + vst1_f16(ptr_h + offset, vcvt_f16_f32(data)); + break; + } + // case at::ScalarType::BFloat16: { + // auto ptr_b = reinterpret_cast(ptr); + // vst1_bf16(ptr_b + offset, vcvt_bf16_f32(data)); + // break; + // } + default: + AT_ERROR("Unsupported dtype"); + break; + } +} + +inline void simd_store(void *ptr, at::ScalarType dtype, float32x4_t data) { + return simd_store_offset(ptr, dtype, data, 0); +} + +inline float32x4_t simd_set(float value) { + auto val = static_cast(value); + return vdupq_n_f32(val); +} + +#endif + +inline float scalar_load_offset(const void *ptr, at::ScalarType dtype, + size_t offset) { + switch (dtype) { + case at::ScalarType::Float: + return *(reinterpret_cast(ptr) + offset); + case at::ScalarType::Half: + return static_cast( + *(reinterpret_cast(ptr) + offset)); + // case at::ScalarType::BFloat16: + // return static_cast( + // *(reinterpret_cast(ptr) + offset)); + default: + AT_ERROR("Unsupported dtype"); + break; + } +} + +inline void scalar_store_offset(void *ptr, at::ScalarType dtype, float data, + size_t offset) { + switch (dtype) { + case at::ScalarType::Float: + *(reinterpret_cast(ptr) + offset) = data; + break; + case at::ScalarType::Half: + *(reinterpret_cast(ptr) + offset) = data; + break; + // case at::ScalarType::BFloat16: + // *(reinterpret_cast(ptr) + offset) = data; + break; + default: + AT_ERROR("Unsupported dtype"); + break; + } +} + +inline void *scalar_seek_offset(void *ptr, at::ScalarType dtype, + size_t offset) { + switch (dtype) { + case at::ScalarType::Float: + return reinterpret_cast(ptr) + offset; + case at::ScalarType::Half: + return reinterpret_cast(ptr) + offset; + // case at::ScalarType::BFloat16: + // return reinterpret_cast(ptr) + offset; + default: + AT_ERROR("Unsupported dtype"); + break; + } +} +#define STEP(SPAN) \ + void Step_##SPAN(void *_params, void *grads, void *_exp_avg, \ + void *_exp_avg_sq, size_t _param_size, \ + at::ScalarType param_dtype, at::ScalarType grad_dtype, \ + at::ScalarType exp_avg_dtype, \ + at::ScalarType exp_avg_sq_dtype, float loss_scale = -1); + +class AdamOptimizer { + private: + float _alpha; + float _betta1; + float _betta2; + float _eps; + float _weight_decay; + + float _betta1_t; + float _betta2_t; + size_t _step; + + float _bias_correction1; + float _bias_correction2; + + bool _adamw_mode; + + public: + AdamOptimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999, + float eps = 1e-8, float weight_decay = 0, + bool adamw_mode = true) + : _alpha(alpha), + _betta1(betta1), + _betta2(betta2), + _eps(eps), + _weight_decay(weight_decay), + _betta1_t(1.0), + _betta2_t(1.0), + _step(0), + _adamw_mode(adamw_mode) {} + ~AdamOptimizer() {} + + STEP(1) + STEP(4) + STEP(8) + inline void IncrementStep(size_t step, float beta1, float beta2) { + if (beta1 != _betta1 || beta2 != _betta2) { + _step = step; + _betta1 = beta1; + _betta2 = beta2; + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + } else { + _step++; + if (_step != step) { + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + _step = step; + } else { + _betta1_t *= _betta1; + _betta2_t *= _betta2; + } + } + } + inline void update_state(float lr, float epsilon, float weight_decay, + bool bias_correction) { + _alpha = lr; + _eps = epsilon; + _weight_decay = weight_decay; + + _bias_correction1 = 1.0f; + _bias_correction2 = 1.0f; + if (bias_correction == 1) { + _bias_correction1 = 1 - _betta1_t; + _bias_correction2 = 1 / sqrt(1 - _betta2_t); + } + } + + void step(size_t step, float lr, float beta1, float beta2, float epsilon, + float weight_decay, bool bias_correction, torch::Tensor ¶ms, + torch::Tensor &grads, torch::Tensor &exp_avg, + torch::Tensor &exp_avg_sq, float loss_scale); +}; diff --git a/colossalai/kernel/cuda_native/mha/utils.py b/colossalai/kernel/cuda_native/mha/utils.py index fe31921b9..5f01e3ef3 100644 --- a/colossalai/kernel/cuda_native/mha/utils.py +++ b/colossalai/kernel/cuda_native/mha/utils.py @@ -5,7 +5,7 @@ import torch import torch.nn.functional as F from einops import rearrange -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import get_current_device class Unpad(torch.autograd.Function): diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_pipeline_schedule.py index 4fc5040f6..5fd5602e7 100644 --- a/colossalai/legacy/engine/schedule/_pipeline_schedule.py +++ b/colossalai/legacy/engine/schedule/_pipeline_schedule.py @@ -12,7 +12,7 @@ from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank from colossalai.logging import get_dist_logger -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import get_current_device from ._base_schedule import BaseSchedule diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py index 867c3dfa8..4cd7e47c3 100644 --- a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py +++ b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py @@ -9,7 +9,7 @@ import colossalai.legacy.communication.p2p_v2 as comm from colossalai.legacy.context.parallel_mode import ParallelMode from colossalai.legacy.core import global_context as gpc from colossalai.legacy.engine import Engine -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import get_current_device from ._pipeline_schedule import PipelineSchedule diff --git a/colossalai/legacy/nn/layer/parallel_1d/layers.py b/colossalai/legacy/nn/layer/parallel_1d/layers.py index 8304cd2e1..b6ec5347f 100644 --- a/colossalai/legacy/nn/layer/parallel_1d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_1d/layers.py @@ -22,7 +22,7 @@ from colossalai.legacy.utils.checkpointing import ( partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import get_current_device from ..base_layer import ParallelLayer from ..colossalai_layer._utils import ColossalaiModule diff --git a/colossalai/legacy/nn/layer/parallel_2d/layers.py b/colossalai/legacy/nn/layer/parallel_2d/layers.py index 3b2e032e5..f81c5334a 100644 --- a/colossalai/legacy/nn/layer/parallel_2d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_2d/layers.py @@ -18,7 +18,7 @@ from colossalai.legacy.utils.checkpointing import ( partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import get_current_device from ..base_layer import ParallelLayer from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py index fc2e35f36..b451a4031 100644 --- a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py @@ -19,7 +19,7 @@ from colossalai.legacy.utils.checkpointing import ( partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import get_current_device from ..base_layer import ParallelLayer from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple diff --git a/colossalai/legacy/nn/layer/parallel_3d/layers.py b/colossalai/legacy/nn/layer/parallel_3d/layers.py index 196679994..16e515f87 100644 --- a/colossalai/legacy/nn/layer/parallel_3d/layers.py +++ b/colossalai/legacy/nn/layer/parallel_3d/layers.py @@ -27,7 +27,7 @@ from colossalai.legacy.utils.checkpointing import ( partition_tensor_parallel_state_dict, ) from colossalai.nn import init as init -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import get_current_device from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple from ._operation import ( diff --git a/colossalai/legacy/nn/layer/vanilla/layers.py b/colossalai/legacy/nn/layer/vanilla/layers.py index 12965a4a6..590ad5ff6 100644 --- a/colossalai/legacy/nn/layer/vanilla/layers.py +++ b/colossalai/legacy/nn/layer/vanilla/layers.py @@ -10,7 +10,7 @@ from torch.nn.parameter import Parameter from colossalai.legacy.context import seed from colossalai.legacy.registry import LAYERS from colossalai.nn import init as init -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import get_current_device from ..utils import to_2tuple diff --git a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py index 19f77d430..e336717f4 100644 --- a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py +++ b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py @@ -3,7 +3,7 @@ import types from time import time from typing import List -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import get_current_device from .stateful_tensor import StatefulTensor, TensorState from .tensor_placement_policy import TensorPlacementPolicy diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index c3c0180e8..7d53a1dd6 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -1,9 +1,10 @@ import math +import platform from typing import Optional import torch -from colossalai.kernel.op_builder import CPUAdamBuilder +from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder from .nvme_optimizer import NVMeOptimizer @@ -77,7 +78,7 @@ class CPUAdam(NVMeOptimizer): default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) self.adamw_mode = adamw_mode - cpu_adam = CPUAdamBuilder().load() + cpu_adam = ArmCPUAdamBuilder().load() if platform.machine() == "aarch64" else CPUAdamBuilder().load() # if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index c7a309b87..d34fd601a 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -84,9 +84,10 @@ class HybridAdam(CPUAdam): nvme_offload_fraction, nvme_offload_dir, ) - fused_optim = FusedOptimBuilder().load() - self.gpu_adam_op = fused_optim.multi_tensor_adam - self._dummy_overflow_buf = torch.cuda.IntTensor([0]) + if torch.cuda.is_available(): + fused_optim = FusedOptimBuilder().load() + self.gpu_adam_op = fused_optim.multi_tensor_adam + self._dummy_overflow_buf = torch.cuda.IntTensor([0]) @torch.no_grad() def step(self, closure=None, div_scale: float = -1): @@ -118,11 +119,11 @@ class HybridAdam(CPUAdam): group_step = state["step"] beta1, beta2 = group["betas"] - if target_device.type == "cpu": - assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu" - assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu" + if target_device.type == "cpu" or target_device.type == "npu": + assert state["exp_avg"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu" + assert state["exp_avg_sq"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu" self._pre_update(p, "exp_avg", "exp_avg_sq") - if p.grad.dtype is torch.bfloat16: + if p.grad.dtype is torch.bfloat16 or p.grad.device.type == "npu": # cpu adam kernel does not support bf16 now bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index bdf122dc0..e1a2d38cd 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -10,7 +10,7 @@ from torch.utils._pytree import tree_map from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import get_current_device from ._utils import get_batch_size, get_micro_batch, model_forward, to_device from .base import PipelineSchedule diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py index 780437155..cbf6dd80f 100644 --- a/colossalai/pipeline/schedule/interleaved_pp.py +++ b/colossalai/pipeline/schedule/interleaved_pp.py @@ -9,7 +9,7 @@ from torch.utils._pytree import tree_map from colossalai.interface import OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import get_current_device from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device from .base import PipelineSchedule diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 1f3b80857..fd918cf19 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -9,7 +9,7 @@ from torch.utils._pytree import tree_map from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.stage_manager import PipelineStageManager -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import get_current_device from ._utils import ( detach, diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index 42efe9a44..8387bb5e3 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -2,16 +2,19 @@ # -*- encoding: utf-8 -*- import warnings from abc import ABC, abstractmethod + import torch.nn as nn + from colossalai.lazy import LazyInitContext -from ._operation import hook_paramter_in_backward +from ._operation import hook_paramter_in_backward from .utils import SeqParallelUtils __all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"] try: from apex.contrib.layer_norm.layer_norm import FastLayerNorm + EnableFastLayerNorm = True except ImportError: EnableFastLayerNorm = False @@ -19,10 +22,27 @@ except ImportError: try: from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm + + class FusedLayerNormWithHook(ApexFusedLayerNorm): + def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True): + super().__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input): + output = super().forward(input) + output = hook_paramter_in_backward(output, self.weight, self.bias) + return output + + class FusedRMSNormWithHook(ApexFusedRMSNorm): + def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True): + super().__init__(normalized_shape, eps, elementwise_affine) + + def forward(self, input): + output = super().forward(input) + output = hook_paramter_in_backward(output, self.weight) + return output + except ImportError: - warnings.warn( - "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel" - ) + warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel") FAST_LAYERNORM_SUPPORTED_SIZE = [ 1024, @@ -52,6 +72,7 @@ FAST_LAYERNORM_SUPPORTED_SIZE = [ ] if EnableFastLayerNorm: + class FastLayerNormWithHook(FastLayerNorm): def __init__(self, hidden_size, eps=0.00001): super().__init__(hidden_size, eps) @@ -60,25 +81,7 @@ if EnableFastLayerNorm: output = super().forward(input) output = hook_paramter_in_backward(output, self.weight, self.bias) return output - -class FusedLayerNormWithHook(ApexFusedLayerNorm): - def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True): - super().__init__(normalized_shape, eps, elementwise_affine) - - def forward(self, input): - output = super().forward(input) - output = hook_paramter_in_backward(output, self.weight, self.bias) - return output - -class FusedRMSNormWithHook(ApexFusedRMSNorm): - def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True): - super().__init__(normalized_shape, eps, elementwise_affine) - - def forward(self, input): - output = super().forward(input) - output = hook_paramter_in_backward(output, self.weight) - return output - + class BaseLayerNorm(ABC): @abstractmethod @@ -244,12 +247,13 @@ class FusedRMSNorm(BaseLayerNorm): """ This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface. """ + def __init__(self) -> None: raise NotImplementedError( "FusedRMSNorm is not implemented as a physical class. " "It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex." ) - + @staticmethod def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module: r""" @@ -264,7 +268,7 @@ class FusedRMSNorm(BaseLayerNorm): nn.Module: FusedRMSNorm module. """ try: - from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm + pass except ImportError: raise ImportError( "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel" @@ -282,7 +286,9 @@ class FusedRMSNorm(BaseLayerNorm): eps = module.eps elementwise_affine = module.elementwise_affine - rmsnorm = FusedRMSNormWithHook(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine) + rmsnorm = FusedRMSNormWithHook( + normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine + ) rmsnorm.weight = module.weight diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py index 3ec39b949..0246a35e2 100644 --- a/colossalai/utils/__init__.py +++ b/colossalai/utils/__init__.py @@ -7,7 +7,7 @@ from .common import ( is_ddp_ignored, set_seed, ) -from .cuda import empty_cache, get_current_device, set_device, set_to_cuda, synchronize +from .device import IS_NPU_AVAILABLE, empty_cache, get_current_device, set_device, set_to_cuda, synchronize from .multi_tensor_apply import multi_tensor_applier from .tensor_detector import TensorDetector from .timer import MultiTimer, Timer @@ -29,4 +29,5 @@ __all__ = [ "set_seed", "is_ddp_ignored", "set_device", + "IS_NPU_AVAILABLE", ] diff --git a/colossalai/utils/cuda.py b/colossalai/utils/cuda.py deleted file mode 100644 index 6bfb08d1f..000000000 --- a/colossalai/utils/cuda.py +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from typing import Optional - -import torch -import torch.distributed as dist - - -def set_to_cuda(models): - """Send model to gpu. - - :param models: nn.module or a list of module - """ - if isinstance(models, list) and len(models) > 1: - ret = [] - for model in models: - ret.append(model.to(get_current_device())) - return ret - elif isinstance(models, list): - return models[0].to(get_current_device()) - else: - return models.to(get_current_device()) - - -def get_current_device() -> torch.device: - """ - Returns currently selected device (gpu/cpu). - If cuda available, return gpu, otherwise return cpu. - """ - if torch.cuda.is_available(): - return torch.device(f"cuda:{torch.cuda.current_device()}") - else: - return torch.device("cpu") - - -def synchronize(): - """Similar to cuda.synchronize(). - Waits for all kernels in all streams on a CUDA device to complete. - """ - if torch.cuda.is_available(): - torch.cuda.synchronize() - - -def empty_cache(): - """Similar to cuda.empty_cache() - Releases all unoccupied cached memory currently held by the caching allocator. - """ - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - -def set_device(index: Optional[int] = None) -> None: - if index is None: - index = dist.get_rank() % torch.cuda.device_count() - torch.cuda.set_device(index) diff --git a/colossalai/utils/device.py b/colossalai/utils/device.py new file mode 100644 index 000000000..e1bd20d59 --- /dev/null +++ b/colossalai/utils/device.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.distributed as dist + +IS_NPU_AVAILABLE: bool = False +try: + import torch_npu # noqa + + IS_NPU_AVAILABLE = torch.npu.is_available() +except ImportError: + pass + + +def set_to_cuda(models): + """Send model to gpu. + + :param models: nn.module or a list of module + """ + if isinstance(models, list) and len(models) > 1: + ret = [] + for model in models: + ret.append(model.to(get_current_device())) + return ret + elif isinstance(models, list): + return models[0].to(get_current_device()) + else: + return models.to(get_current_device()) + + +def get_current_device() -> torch.device: + """ + Returns currently selected device (gpu/cpu). + If cuda available, return gpu, otherwise return cpu. + """ + if torch.cuda.is_available(): + return torch.device(f"cuda:{torch.cuda.current_device()}") + elif IS_NPU_AVAILABLE: + return torch.device(f"npu:{torch.npu.current_device()}") + else: + return torch.device("cpu") + + +def _dispatch_device_func(fn_name: str, *args, **kwargs): + if torch.cuda.is_available(): + return getattr(torch.cuda, fn_name)(*args, **kwargs) + elif IS_NPU_AVAILABLE: + return getattr(torch.npu, fn_name)(*args, **kwargs) + else: + raise RuntimeError("No device available") + + +# device semantics + + +def can_device_access_peer(device, peer_device) -> bool: + return _dispatch_device_func("can_device_access_peer", device, peer_device) + + +def current_device() -> int: + return _dispatch_device_func("current_device") + + +def current_stream(device=None): + return _dispatch_device_func("current_stream", device) + + +def default_stream(device=None): + return _dispatch_device_func("default_stream", device) + + +def device_count() -> int: + return _dispatch_device_func("device_count") + + +def get_device_capability(device=None) -> Tuple[int, int]: + return _dispatch_device_func("get_device_capability", device) + + +def get_device_name(device=None) -> str: + return _dispatch_device_func("get_device_name", device) + + +def get_device_properties(device): + return _dispatch_device_func("get_device_properties", device) + + +def set_device(index: Optional[int] = None) -> None: + if index is None: + index = dist.get_rank() % device_count() + _dispatch_device_func("set_device", index) + + +def set_stream(stream_): + return _dispatch_device_func("set_stream", stream_) + + +def stream(stream_): + return _dispatch_device_func("stream", stream_) + + +def synchronize(): + return _dispatch_device_func("synchronize") + + +def utilization(device=None) -> int: + return _dispatch_device_func("utilization", device) + + +# random number generator + + +def get_rng_state(device="cuda") -> torch.Tensor: + return _dispatch_device_func("get_rng_state", device) + + +def get_rng_state_all() -> List[torch.Tensor]: + return _dispatch_device_func("get_rng_state_all") + + +def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None: + return _dispatch_device_func("set_rng_state", new_state, device) + + +def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None: + return _dispatch_device_func("set_rng_state_all", new_states) + + +def manual_seed(seed: int) -> None: + return _dispatch_device_func("manual_seed", seed) + + +def manual_seed_all(seed: int) -> None: + return _dispatch_device_func("manual_seed_all", seed) + + +def seed() -> None: + return _dispatch_device_func("seed") + + +def seed_all() -> None: + return _dispatch_device_func("seed_all") + + +def initial_seed() -> int: + return _dispatch_device_func("initial_seed") + + +# streams and events + + +def Stream(device=None, priority=0, **kwargs): + return _dispatch_device_func("Stream", device, priority, **kwargs) + + +def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False): + return _dispatch_device_func("Event", enable_timing, blocking, interprocess) + + +# memory management + + +def empty_cache() -> None: + return _dispatch_device_func("empty_cache") + + +def memory_stats(device=None) -> Dict[str, Any]: + return _dispatch_device_func("memory_stats", device) + + +def memory_summary(device=None, abbreviated=False) -> str: + return _dispatch_device_func("memory_summary", device, abbreviated) + + +def memory_snapshot(): + return _dispatch_device_func("memory_snapshot") + + +def memory_allocated(device=None) -> int: + return _dispatch_device_func("memory_allocated", device) + + +def max_memory_allocated(device=None) -> int: + return _dispatch_device_func("max_memory_allocated", device) + + +def reset_max_memory_allocated(device=None) -> None: + return _dispatch_device_func("reset_max_memory_allocated", device) + + +def memory_reserved(device=None) -> int: + return _dispatch_device_func("memory_reserved", device) + + +def max_memory_reserved(device=None) -> int: + return _dispatch_device_func("max_memory_reserved", device) + + +def set_per_process_memory_fraction(fraction: float, device=None) -> None: + return _dispatch_device_func("set_per_process_memory_fraction", fraction, device) + + +def reset_peak_memory_stats(device=None) -> None: + return _dispatch_device_func("reset_peak_memory_stats", device) diff --git a/colossalai/utils/timer.py b/colossalai/utils/timer.py index 2f61817f0..8ab6b46f2 100644 --- a/colossalai/utils/timer.py +++ b/colossalai/utils/timer.py @@ -3,7 +3,7 @@ import time from typing import Tuple -from .cuda import synchronize +from .device import synchronize class Timer: diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py index 42a8cdbb3..ff92ab89d 100644 --- a/colossalai/zero/gemini/chunk/chunk.py +++ b/colossalai/zero/gemini/chunk/chunk.py @@ -7,6 +7,7 @@ import torch.distributed as dist from torch.distributed import ProcessGroup from colossalai.utils import get_current_device +from colossalai.utils.device import IS_NPU_AVAILABLE class TensorState(Enum): @@ -172,7 +173,7 @@ class Chunk: if self.chunk_temp is not None: # this chunk is not closed - if self.chunk_temp.device.type == "cuda": + if self.chunk_temp.device.type == "cuda" or self.chunk_temp.device.type == "npu": cuda_memory += self.chunk_mem else: cpu_memory += self.chunk_mem @@ -191,10 +192,8 @@ class Chunk: if self.chunk_temp is not None: return self.chunk_temp.device.type else: - if self.is_gathered: - return "cuda" - elif self.cuda_shard is not None: - return "cuda" + if self.is_gathered or self.cuda_shard is not None: + return "npu" if IS_NPU_AVAILABLE else "cuda" else: return "cpu" @@ -329,12 +328,12 @@ class Chunk: # when the current chunk is not synchronized with the optimizer # just use another way for the movement if not self.optim_sync_flag: - assert device.type == "cuda", "each chunk should first be moved to CUDA" + assert device.type == "cuda" or device.type == "npu", "each chunk should first be moved to CUDA" self.__paired_shard_move() self.optim_sync_flag = True return - if device.type == "cuda": + if device.type == "cuda" or device.type == "npu": assert device == get_current_device(), "can't move chunk to another device" if self.cuda_shard: @@ -484,7 +483,7 @@ class Chunk: assert friend_chunk.is_gathered is True self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk) self.optim_sync_flag = True - elif friend_chunk.device_type == "cuda" and self.device_type == "cuda": + elif friend_chunk.device_type in ("cuda", "npu") and self.device_type in ("cuda", "npu"): self.cuda_shard.copy_(friend_chunk.cuda_shard) self.optim_sync_flag = True self.cpu_vis_flag = False diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py index 5ad622a13..974943747 100644 --- a/colossalai/zero/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -206,7 +206,10 @@ class ChunkManager: tensor (torch.Tensor): An extern static tensor. E.g. optimizer state. """ assert tensor not in self.tensor_chunk_map - self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size() + device_type = tensor.device.type + if device_type == "npu": + device_type = "cuda" + self.total_mem[device_type] += tensor.numel() * tensor.element_size() def __repr__(self) -> str: msg = [ diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index ff943f4b4..0b70ec742 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -10,32 +10,30 @@ import torch.nn as nn from torch.distributed import ProcessGroup from torch.distributed.distributed_c10d import _get_default_group -from colossalai.checkpoint_io.utils import StateDictSharder +from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param from colossalai.interface import ModelWrapper from colossalai.lazy import LazyTensor from colossalai.logging import get_dist_logger from colossalai.tensor.colo_parameter import ColoParameter -from colossalai.tensor.param_op_hook import ColoParamOpHookManager -from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored -from colossalai.checkpoint_io.utils import gather_distributed_param - -from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager -from .gemini_hook import GeminiZeROHook -from .gemini_mgr import GeminiManager -from .memory_tracer import MemStats, OrderedParamGenerator -from .utils import get_temp_total_chunk_on_cuda - from colossalai.tensor.d_tensor import ( distribute_tensor, distribute_tensor_with_customization, - init_tensor_as_customization_distributed, get_device_mesh, + get_global_shape, get_sharding_spec, + init_as_dtensor, + init_tensor_as_customization_distributed, is_customized_distributed_tensor, is_distributed_tensor, - get_global_shape, - init_as_dtensor ) +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored + +from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager +from .gemini_hook import GeminiZeROHook +from .gemini_mgr import GeminiManager +from .memory_tracer import MemStats, OrderedParamGenerator +from .utils import get_temp_total_chunk_on_cuda try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys @@ -162,7 +160,7 @@ class GeminiDDP(ModelWrapper): self._init_chunks( param_order=param_order, strict_ddp_mode=strict_ddp_mode, - cpu_offload=self.gemini_manager.policy_name != "cuda", + cpu_offload=not (self.gemini_manager.policy_name == "static" and offload_param_frac == 0), pin_memory=pin_memory, ) super().__init__(module) @@ -453,12 +451,13 @@ class GeminiDDP(ModelWrapper): global_shape = get_global_shape(tensor) device_mesh = get_device_mesh(tensor) shard_spec = get_sharding_spec(tensor) - record_tensor = init_as_dtensor(record_tensor, - device_mesh=device_mesh, - sharding_spec=shard_spec, - global_shape = global_shape) + record_tensor = init_as_dtensor( + record_tensor, device_mesh=device_mesh, sharding_spec=shard_spec, global_shape=global_shape + ) elif is_customized_distributed_tensor(tensor): - init_tensor_as_customization_distributed(record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn) + init_tensor_as_customization_distributed( + record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn + ) record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu() assert tensor not in chunk_to_save_data @@ -634,7 +633,15 @@ class GeminiDDP(ModelWrapper): local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items()) local_state = {k: v for k, v in local_name_params if v is not None} - def load(param_name, dest_tensor, copy_func, source_device_mesh=None, source_sharding_spec=None, shard_fn=None, gather_fn=None): + def load( + param_name, + dest_tensor, + copy_func, + source_device_mesh=None, + source_sharding_spec=None, + shard_fn=None, + gather_fn=None, + ): state_key = prefix + param_name if state_key in state_dict: input_param = state_dict[state_key] @@ -642,7 +649,9 @@ class GeminiDDP(ModelWrapper): if source_device_mesh is not None and source_sharding_spec is not None: input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec) elif shard_fn is not None and gather_fn is not None: - input_param = distribute_tensor_with_customization(input_param, shard_fn=shard_fn, gather_fn=gather_fn) + input_param = distribute_tensor_with_customization( + input_param, shard_fn=shard_fn, gather_fn=gather_fn + ) # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1: @@ -687,7 +696,6 @@ class GeminiDDP(ModelWrapper): temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision) for tensor, tensor_info in chunk.tensors_info.items(): - source_device_mesh, source_sharding_spec, shard_fn, gather_fn = None, None, None, None if is_distributed_tensor(tensor): # shard the input param @@ -699,7 +707,15 @@ class GeminiDDP(ModelWrapper): parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor] parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end] - load(parameter_name, tensor, partial(load_parameter, parameter_slice), source_device_mesh, source_sharding_spec, shard_fn, gather_fn) + load( + parameter_name, + tensor, + partial(load_parameter, parameter_slice), + source_device_mesh, + source_sharding_spec, + shard_fn, + gather_fn, + ) if chunk.is_gathered: chunk.cuda_global_chunk.copy_(temp_chunk) @@ -799,7 +815,7 @@ class GeminiDDP(ModelWrapper): for buffer in self.module.buffers(): if isinstance(buffer, LazyTensor): buffer.materialize() - buffer.data = buffer.cuda() + buffer.data = buffer.to(get_current_device()) if torch.is_floating_point(buffer): buffer.data = buffer.to(self.mixed_precision) diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py index f7ff3f6cd..150932e3d 100644 --- a/colossalai/zero/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -17,9 +17,7 @@ class GeminiManager: https://arxiv.org/abs/2108.05818 Args: - placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' and 'auto'. - If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means min CUDA memory will be used. - If it's 'cuda', they won't be offloaded, which means max CUDA memory will be used. + placement_policy (str): Which device to place *held* tensors. It can be 'static' and 'auto'. If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well. Note that 'auto' policy can only work well when no other processes use CUDA during your training. chunk_manager (ChunkManager): A ``ChunkManager`` instance. @@ -121,7 +119,7 @@ class GeminiManager: start = time() cuda_demand = 0 for chunk in chunks: - if chunk.device_type == "cuda": + if chunk.device_type == "cuda" or chunk.device_type == "npu": if chunk.is_gathered: pass else: diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index e20d846f1..50d4f51d3 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -7,31 +7,29 @@ from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union import torch import torch.distributed as dist from packaging.version import Version +from torch.distributed import ProcessGroup from torch.nn import Parameter from torch.optim import Optimizer -from torch.distributed import ProcessGroup from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin -from colossalai.checkpoint_io.utils import StateDictSharder +from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam -from colossalai.utils import disposable, get_current_device, is_ddp_ignored - -from .chunk import Chunk, ChunkManager -from .gemini_ddp import GeminiDDP -from colossalai.checkpoint_io.utils import gather_distributed_param from colossalai.tensor.d_tensor import ( distribute_tensor, distribute_tensor_with_customization, - init_tensor_as_customization_distributed, get_device_mesh, get_sharding_spec, + init_as_dtensor, + init_tensor_as_customization_distributed, is_customized_distributed_tensor, is_distributed_tensor, - get_global_shape, - init_as_dtensor ) +from colossalai.utils import disposable, get_current_device, is_ddp_ignored + +from .chunk import Chunk, ChunkManager +from .gemini_ddp import GeminiDDP __all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"] @@ -312,7 +310,7 @@ class GeminiOptimizer(OptimizerWrapper): chunk16 = self.param_to_chunk16[fake_param] chunk32 = chunk16.paired_chunk - if chunk32.device_type == "cuda": + if chunk32.device_type == "cuda" or chunk32.device_type == "npu": continue if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem: @@ -326,7 +324,7 @@ class GeminiOptimizer(OptimizerWrapper): for fake_param in group["params"]: chunk16 = self.param_to_chunk16[fake_param] chunk32 = chunk16.paired_chunk - if chunk32.device_type == "cuda": + if chunk32.device_type == "cuda" or chunk32.device_type == "npu": state = self.optim.state[fake_param] for k, v in state.items(): if isinstance(v, torch.Tensor): @@ -479,15 +477,19 @@ class GeminiOptimizer(OptimizerWrapper): state_tensor = states[state_name].detach().clone().to(torch.float32).cpu() if is_dtensor: state_tensor = torch.reshape(state_tensor, param.shape).to(param.device) - state_tensor = init_as_dtensor(state_tensor, - device_mesh=device_mesh, - sharding_spec=shard_spec, - global_shape = global_shape) + state_tensor = init_as_dtensor( + state_tensor, + device_mesh=device_mesh, + sharding_spec=shard_spec, + global_shape=global_shape, + ) elif is_customized_distributed: state_tensor = torch.reshape(state_tensor, param.shape).to(param.device) - init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn) + init_tensor_as_customization_distributed( + state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn + ) state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() - + collected_states[state_name] = state_tensor.reshape(global_shape) return collected_states @@ -533,13 +535,14 @@ class GeminiOptimizer(OptimizerWrapper): collected_states[state_name] = torch.reshape(state_tensor, param.shape) if is_dtensor: state_tensor = state_tensor.to(param.device) - state_tensor = init_as_dtensor(state_tensor, - sharding_spec=shard_spec, - device_mesh=device_mesh, - global_shape=global_shape) + state_tensor = init_as_dtensor( + state_tensor, sharding_spec=shard_spec, device_mesh=device_mesh, global_shape=global_shape + ) elif is_customized_distributed: state_tensor = state_tensor.to(param.device) - init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn) + init_tensor_as_customization_distributed( + state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn + ) state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() return collected_states @@ -548,7 +551,7 @@ class GeminiOptimizer(OptimizerWrapper): self, param_id: int, state_names: list, - device: torch.device = torch.device("cuda"), + device: torch.device = get_current_device(), dtype: torch.dtype = torch.float32, ) -> torch.Tensor: """ @@ -705,7 +708,7 @@ class GeminiOptimizer(OptimizerWrapper): ret_val = torch.zeros( state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False ) - + if is_dtensor: value = torch.reshape(value, global_shape) value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh) diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index d61082bed..c1b35ee17 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -12,6 +12,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch.distributed import ProcessGroup from torch.optim import Optimizer +import colossalai.utils.device as device_utils from colossalai.amp.naive_amp.mixed_precision_mixin import ( BF16MixedPrecisionMixin, FP16MixedPrecisionMixin, @@ -22,7 +23,7 @@ from colossalai.logging import get_dist_logger from colossalai.tensor.moe_tensor.api import is_moe_tensor # from colossalai.tensor import ColoParameter, ProcessGroup -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor from .bookkeeping import BucketStore, GradientStore, ParameterStore @@ -182,7 +183,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # intialize communication stream for # communication-compuation overlapping if self._overlap_communication: - self._comm_stream = torch.cuda.Stream() + self._comm_stream = device_utils.Stream() # reduction hook is only used if overlapping communication # or stage 2 is used @@ -216,7 +217,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): return len(self._working_param_groups) def _sanity_checks(self): - assert torch.cuda.is_available(), "CUDA is required" + assert torch.cuda.is_available() or IS_NPU_AVAILABLE, "device is required" for param_group in self.optim.param_groups: group_params = param_group["params"] for param in group_params: @@ -339,11 +340,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper): if len(moe_grad_list) > 0: moe_flat_grads.record_stream(stream) # waiting for ops in the default stream finishing - stream.wait_stream(torch.cuda.current_stream()) + stream.wait_stream(device_utils.current_stream()) else: - stream = torch.cuda.current_stream() + stream = device_utils.current_stream() - with torch.cuda.stream(stream): + with device_utils.stream(stream): group_id = self._bucket_store.current_group_id if self.moe_extra_dp_pg is None: @@ -485,7 +486,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # clear reduced grads if self._overlap_communication: - torch.cuda.synchronize() + device_utils.synchronize() self.zero_grad() @@ -504,7 +505,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): # clear reduced grads if self._overlap_communication: - torch.cuda.synchronize() + device_utils.synchronize() self.zero_grad() @@ -620,22 +621,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper): release_param_grad(self._master_param_groups_of_current_rank[group_id]) # update working partition updated by the current rank + device = get_current_device() for group_id in range(self.num_param_groups): master_working_param = self.optim.param_groups[group_id]["params"] for idx, splited_param in enumerate(master_working_param): working_param = real_working_params[group_id][idx] if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param): all_splited_param = [ - torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) + torch.zeros(splited_param.shape, device=device, dtype=self._dtype) for _ in range(self.moe_extra_dp_pg_size) ] - dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.moe_extra_dp_pg) + dist.all_gather( + all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg + ) else: all_splited_param = [ - torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) + torch.zeros(splited_param.shape, device=device, dtype=self._dtype) for _ in range(self._world_size) ] - dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg) + dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg) working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] @@ -657,7 +661,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): norm_type = float(norm_type) if norm_type == inf: total_norm = max(grad.data.abs().max() for grad in gradients) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float) dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg) total_norm = total_norm_cuda.item() @@ -668,7 +672,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper): total_norm_exponentiated += grad_norm_exponentiated # Sum across all model parallel GPUs. - total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)]) + total_norm_exponentiated_cuda = torch.tensor( + [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float + ) torch.distributed.all_reduce( total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg ) @@ -759,6 +765,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): Dict: the pytorch form state_dict """ zero_state = dict() + device = get_current_device() for param, state in self.optim.state.items(): zero_state[param] = copy.deepcopy(state) for k, v in state.items(): @@ -766,14 +773,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper): working_param = self._param_store.master_to_working_param[id(param)] if self.moe_extra_dp_pg is not None and is_moe_tensor(v): gather_tensor = [ - torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size) + torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size) ] - dist.all_gather(gather_tensor, v.cuda(), group=self.moe_extra_dp_pg) + dist.all_gather(gather_tensor, v.to(device), group=self.moe_extra_dp_pg) else: gather_tensor = [ - torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size) + torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size) ] - dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg) + dist.all_gather(gather_tensor, v.to(device), group=self.dp_pg) param_state = ( torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() ) @@ -820,6 +827,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper): ret_block = dict() ret_block_size = 0 + device = get_current_device() local_states = self.optim.state_dict()["state"] for param_idx, states in local_states.items(): current_block_size = 0 @@ -836,14 +844,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper): if isinstance(v, torch.Tensor) and k != "step": if self.moe_extra_dp_pg is not None and is_moe_tensor(v): state_tensor = [ - torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size) + torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size) ] - dist.all_gather(state_tensor, v.cuda(), group=self.moe_extra_dp_pg) + dist.all_gather(state_tensor, v.to(device), group=self.moe_extra_dp_pg) else: state_tensor = [ - torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size) + torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size) ] - dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg) + dist.all_gather(state_tensor, v.to(device), group=self.dp_pg) state_tensor = ( torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() ) diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index b38ddbb4a..47fc9e2a7 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -13,6 +13,7 @@ from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaForCausalLM import colossalai +import colossalai.utils.device as device_utils from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin from colossalai.cluster import DistCoordinator @@ -194,7 +195,7 @@ def main(): torch.set_default_dtype(torch.bfloat16) model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) torch.set_default_dtype(torch.float) - coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master(f"Booster init max CUDA memory: {device_utils.max_memory_allocated()/1024**2:.2f} MB") coordinator.print_on_master( f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" ) @@ -220,7 +221,7 @@ def main(): performance_evaluator.on_step_end(**batch) performance_evaluator.on_fit_end() - coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") + coordinator.print_on_master(f"Max CUDA memory usage: {device_utils.max_memory_allocated()/1024**2:.2f} MB") if __name__ == "__main__": diff --git a/examples/language/llama2/performance_evaluator.py b/examples/language/llama2/performance_evaluator.py index 05e71edf1..4bea5c81a 100644 --- a/examples/language/llama2/performance_evaluator.py +++ b/examples/language/llama2/performance_evaluator.py @@ -5,7 +5,9 @@ import torch import torch.distributed as dist from torch import Tensor +import colossalai.utils.device as device_utils from colossalai.cluster import DistCoordinator +from colossalai.utils.device import get_current_device def divide(x: float, y: float) -> float: @@ -20,7 +22,7 @@ def divide(x: float, y: float) -> float: def all_reduce_mean(x: float, world_size: int) -> float: if world_size == 1: return x - tensor = torch.tensor([x], device=torch.cuda.current_device()) + tensor = torch.tensor([x], device=get_current_device()) dist.all_reduce(tensor) tensor = tensor / world_size return tensor.item() @@ -84,13 +86,13 @@ class PerformanceEvaluator: self.disable = self.ignore_steps > 0 and step < self.ignore_steps if self.disable: return - torch.cuda.synchronize() + device_utils.synchronize() self.timer.start() def on_step_end(self, input_ids: Tensor, **kwargs) -> None: if self.disable: return - torch.cuda.synchronize() + device_utils.synchronize() self.timer.end() batch_size, seq_len = input_ids.shape diff --git a/op_builder/__init__.py b/op_builder/__init__.py index 808559ec9..21e216437 100644 --- a/op_builder/__init__.py +++ b/op_builder/__init__.py @@ -1,3 +1,4 @@ +from .arm_cpu_adam import ArmCPUAdamBuilder from .cpu_adam import CPUAdamBuilder from .fused_optim import FusedOptimBuilder from .layernorm import LayerNormBuilder @@ -29,4 +30,5 @@ __all__ = [ "MultiTensorLambBuilder", "MultiTensorScaleBuilder", "MultiTensorL2NormBuilder", + "ArmCPUAdamBuilder", ] diff --git a/op_builder/arm_cpu_adam.py b/op_builder/arm_cpu_adam.py new file mode 100644 index 000000000..18dd519fa --- /dev/null +++ b/op_builder/arm_cpu_adam.py @@ -0,0 +1,34 @@ +from .builder import Builder + + +class ArmCPUAdamBuilder(Builder): + NAME = "arm_cpu_adam" + PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam" + ext_type = "cpu" + + def __init__(self): + super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH) + self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"] + + # necessary 4 functions + def sources_files(self): + ret = [ + self.csrc_abs_path("cpu_adam_arm.cpp"), + ] + return ret + + def include_dirs(self): + return [self.csrc_abs_path("includes")] + + def cxx_flags(self): + extra_cxx_flags = [ + "-std=c++14", + "-std=c++17", + "-g", + "-Wno-reorder", + "-fopenmp", + ] + return ["-O3"] + self.version_dependent_macros + extra_cxx_flags + + def nvcc_flags(self): + return [] diff --git a/op_builder/builder.py b/op_builder/builder.py index 75823ef10..d804cb160 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -7,7 +7,7 @@ import os import time from abc import ABC, abstractmethod from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Union from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0 @@ -21,6 +21,8 @@ class Builder(ABC): prebuilt_import_path (str): the path where the extension is installed during pip install """ + ext_type: str = "cuda" + def __init__(self, name: str, prebuilt_import_path: str): self.name = name self.prebuilt_import_path = prebuilt_import_path @@ -165,7 +167,8 @@ class Builder(ABC): ) except ImportError: # check environment - self.check_runtime_build_environment() + if self.ext_type == "cuda": + self.check_runtime_build_environment() # time the kernel compilation start_build = time.time() @@ -208,11 +211,19 @@ class Builder(ABC): return op_module - def builder(self) -> "CUDAExtension": + def builder(self) -> Union["CUDAExtension", "CppExtension"]: """ get a CUDAExtension instance used for setup.py """ - from torch.utils.cpp_extension import CUDAExtension + from torch.utils.cpp_extension import CppExtension, CUDAExtension + + if self.ext_type == "cpp": + return CppExtension( + name=self.prebuilt_import_path, + sources=self.strip_empty_entries(self.sources_files()), + include_dirs=self.strip_empty_entries(self.include_dirs()), + extra_compile_args=self.strip_empty_entries(self.cxx_flags()), + ) return CUDAExtension( name=self.prebuilt_import_path, diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index 104ca254c..3eaaf882c 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -2,11 +2,14 @@ from typing import Optional import torch import torch.distributed as dist +from torch.optim import Adam import colossalai +import colossalai.utils.device as device_utils from colossalai.booster import Booster from colossalai.booster.plugin import LowLevelZeroPlugin -from colossalai.nn.optimizer import HybridAdam + +# from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo @@ -19,16 +22,17 @@ _STUCK_MODELS = ["transformers_albert_for_multiple_choice"] def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: + device = device_utils.get_current_device() try: plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5) booster = Booster(plugin=plugin) model = model_fn() - optimizer = HybridAdam(model.parameters(), lr=1e-3) + optimizer = Adam(model.parameters(), lr=1e-3) criterion = lambda x: x.mean() data = data_gen_fn() data = { - k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() + k: v.to(device) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() } model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) @@ -65,7 +69,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): continue err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn) - torch.cuda.empty_cache() + device_utils.empty_cache() if err is None: passed_models.append(name) @@ -89,7 +93,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True): @rerun_if_address_is_in_use() def test_low_level_zero_plugin(early_stop: bool = True): - spawn(run_dist, 4, early_stop=early_stop) + spawn(run_dist, 2, early_stop=early_stop) if __name__ == "__main__": diff --git a/tests/test_legacy/test_utils/test_memory.py b/tests/test_legacy/test_utils/test_memory.py index 9416ac86e..9df7cf75a 100644 --- a/tests/test_legacy/test_utils/test_memory.py +++ b/tests/test_legacy/test_utils/test_memory.py @@ -3,7 +3,7 @@ import pytest import colossalai from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction from colossalai.testing import spawn -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import get_current_device def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py index b8d3f45e0..21afff753 100644 --- a/tests/test_zero/test_gemini/test_fwd_bwd.py +++ b/tests/test_zero/test_gemini/test_fwd_bwd.py @@ -9,7 +9,7 @@ from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd_bwd diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py index bfd3ebfcb..35323e516 100644 --- a/tests/test_zero/test_gemini/test_grad_accum.py +++ b/tests/test_zero/test_gemini/test_grad_accum.py @@ -9,7 +9,7 @@ import colossalai from colossalai.nn.optimizer import HybridAdam from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py index e20428b67..152bf2895 100644 --- a/tests/test_zero/test_gemini/test_inference.py +++ b/tests/test_zero/test_gemini/test_inference.py @@ -11,7 +11,7 @@ from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 887e495e6..405d7d789 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -9,7 +9,7 @@ from colossalai.legacy.amp import convert_to_apex_amp from colossalai.nn.optimizer import HybridAdam from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.utils import set_seed -from colossalai.utils.cuda import get_current_device +from colossalai.utils.device import get_current_device from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.chunk import search_chunk_configuration from tests.kit.model_zoo import model_zoo, run_fwd_bwd diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py index 3c5baea13..351ae5f67 100644 --- a/tests/test_zero/test_low_level/test_grad_acc.py +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -9,7 +9,7 @@ from torch.testing import assert_close import colossalai from colossalai.testing import spawn from colossalai.testing.random import seed_all -from colossalai.utils import conditional_context +from colossalai.utils import conditional_context, get_current_device from colossalai.zero import LowLevelZeroOptimizer @@ -28,9 +28,9 @@ class MlpModel(nn.Module): def exam_zero_1_2_grad_acc(): local_rank = torch.distributed.get_rank() seed_all(2009) - + device = get_current_device() # create model - zero1_model = MlpModel().cuda() + zero1_model = MlpModel().to(device) zero2_model = copy.deepcopy(zero1_model) # create optimizer zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) @@ -43,8 +43,8 @@ def exam_zero_1_2_grad_acc(): ) # create data seed_all(2021 + local_rank) - input_data1 = torch.randn(32, 128).cuda() - input_data2 = torch.randn(32, 128).cuda() + input_data1 = torch.randn(32, 128, device=device) + input_data2 = torch.randn(32, 128, device=device) def fwd_bwd_func(number, cur_data, check_flag): # zero-dp forward @@ -71,14 +71,15 @@ def exam_zero_1_2_grad_acc(): def exam_zero_1_grad_acc(sync): local_rank = torch.distributed.get_rank() seed_all(2008) + device = get_current_device() # create models zero_model = MlpModel() torch_model = copy.deepcopy(zero_model) seed_all(2008) - zero_model = zero_model.cuda() - torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0) + zero_model = zero_model.to(device) + torch_model = DDP(torch_model.to(device), bucket_cap_mb=0) # create optimizer zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1) @@ -94,8 +95,8 @@ def exam_zero_1_grad_acc(sync): # create data seed_all(2022 + local_rank) - input_data1 = torch.randn(32, 128).cuda() - input_data2 = torch.randn(32, 128).cuda() + input_data1 = torch.randn(32, 128, device=device) + input_data2 = torch.randn(32, 128, device=device) def fwd_bwd_func(no_sync, cur_data, check_flag): # zero1 fwd and bwd