
[npu] add npu support for gemini and zero (#5067)

* [npu] setup device utils (#5047)

* [npu] add npu device support

* [npu] support low level zero

* [test] update npu zero plugin test

* [hotfix] fix import

* [test] recover tests

* [npu] gemini support npu (#5052)

* [npu] refactor device utils

* [gemini] support npu

* [example] llama2+gemini support npu

* [kernel] add arm cpu adam kernel (#5065)

* [kernel] add arm cpu adam

* [optim] update adam optimizer

* [kernel] arm cpu adam remove bf16 support
Branch: pull/5072/head · Author: Hongxin Liu · Commit: e5ce4c8ea6
Files changed (46; changed-line counts in parentheses):

  1. colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py (3)
  2. colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py (8)
  3. colossalai/auto_parallel/offload/solver.py (2)
  4. colossalai/booster/plugin/gemini_plugin.py (10)
  5. colossalai/booster/plugin/low_level_zero_plugin.py (2)
  6. colossalai/initialize.py (7)
  7. colossalai/kernel/cuda_native/csrc/cpu_adam.h (2)
  8. colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp (304)
  9. colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h (201)
  10. colossalai/kernel/cuda_native/mha/utils.py (2)
  11. colossalai/legacy/engine/schedule/_pipeline_schedule.py (2)
  12. colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py (2)
  13. colossalai/legacy/nn/layer/parallel_1d/layers.py (2)
  14. colossalai/legacy/nn/layer/parallel_2d/layers.py (2)
  15. colossalai/legacy/nn/layer/parallel_2p5d/layers.py (2)
  16. colossalai/legacy/nn/layer/parallel_3d/layers.py (2)
  17. colossalai/legacy/nn/layer/vanilla/layers.py (2)
  18. colossalai/legacy/zero/gemini/stateful_tensor_mgr.py (2)
  19. colossalai/nn/optimizer/cpu_adam.py (5)
  20. colossalai/nn/optimizer/hybrid_adam.py (15)
  21. colossalai/pipeline/schedule/generate.py (2)
  22. colossalai/pipeline/schedule/interleaved_pp.py (2)
  23. colossalai/pipeline/schedule/one_f_one_b.py (2)
  24. colossalai/shardformer/layer/normalization.py (58)
  25. colossalai/utils/__init__.py (3)
  26. colossalai/utils/cuda.py (56)
  27. colossalai/utils/device.py (207)
  28. colossalai/utils/timer.py (2)
  29. colossalai/zero/gemini/chunk/chunk.py (15)
  30. colossalai/zero/gemini/chunk/manager.py (5)
  31. colossalai/zero/gemini/gemini_ddp.py (66)
  32. colossalai/zero/gemini/gemini_mgr.py (6)
  33. colossalai/zero/gemini/gemini_optimizer.py (53)
  34. colossalai/zero/low_level/low_level_optim.py (52)
  35. examples/language/llama2/benchmark.py (5)
  36. examples/language/llama2/performance_evaluator.py (8)
  37. op_builder/__init__.py (2)
  38. op_builder/arm_cpu_adam.py (34)
  39. op_builder/builder.py (19)
  40. tests/test_booster/test_plugin/test_low_level_zero_plugin.py (14)
  41. tests/test_legacy/test_utils/test_memory.py (2)
  42. tests/test_zero/test_gemini/test_fwd_bwd.py (2)
  43. tests/test_zero/test_gemini/test_grad_accum.py (2)
  44. tests/test_zero/test_gemini/test_inference.py (2)
  45. tests/test_zero/test_gemini/test_optim.py (2)
  46. tests/test_zero/test_low_level/test_grad_acc.py (19)

colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py (3)

@@ -8,6 +8,7 @@ import torch
 from torch import Tensor

 from colossalai.logging import get_dist_logger
+from colossalai.utils.device import get_current_device

 __all__ = ["BaseGradScaler"]

@@ -22,7 +23,7 @@ class BaseGradScaler(ABC):
     def __init__(self, initial_scale: float, verbose: bool):
         assert initial_scale > 0
-        self._scale = torch.cuda.FloatTensor([initial_scale])
+        self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float)
         self._verbose = verbose

         if self._verbose:

colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py (8)

@@ -5,6 +5,8 @@ from typing import Optional

 import torch

+from colossalai.utils.device import get_current_device

 from .base_grad_scaler import BaseGradScaler

 __all__ = ["DynamicGradScaler"]

@@ -37,12 +39,12 @@ class DynamicGradScaler(BaseGradScaler):
     ):
         super().__init__(initial_scale, verbose)
         if min_scale:
-            self._min_scale = torch.cuda.FloatTensor([min_scale])
+            self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float)
         else:
             self._min_scale = None

         if max_scale:
-            self._max_scale = torch.cuda.FloatTensor([max_scale])
+            self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float)
         else:
             self._max_scale = None

@@ -115,7 +117,7 @@ class DynamicGradScaler(BaseGradScaler):
         return state_dict

     def load_state_dict(self, state_dict):
-        self._scale = state_dict["scale"].cuda(torch.cuda.current_device())
+        self._scale = state_dict["scale"].to(get_current_device())
         self._growth_factor = state_dict["growth_factor"]
         self._backoff_factor = state_dict["backoff_factor"]
         self._hysteresis = state_dict["hysteresis"]
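
Both scalers now build their scalar state through get_current_device() instead of torch.cuda.FloatTensor, so the same code path works on a CUDA GPU, an Ascend NPU, or a CPU-only host. A minimal sketch of the pattern in isolation (illustrative, not taken from the diff):

import torch
from colossalai.utils.device import get_current_device

initial_scale = 2.0**16
# allocate on whichever device this rank owns; falls back to CPU if none
scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float)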

colossalai/auto_parallel/offload/solver.py (2)

@@ -11,7 +11,7 @@ except:
 import torch
 from torch.fx.node import Node

-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device

 from .region import Region
 from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator

colossalai/booster/plugin/gemini_plugin.py (10)

@@ -25,6 +25,7 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.utils import get_current_device
+from colossalai.utils.device import IS_NPU_AVAILABLE
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats

@@ -37,6 +38,7 @@ PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
 ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2

 def get_param_info(optim: Optimizer):
     # Get a backup of necessary information of parameters for future use, which includes:
     # 1. A mapping from integer param_id to param32 shape.

@@ -53,6 +55,8 @@ def get_param_info(optim: Optimizer):
         start_index += len(group["params"])
     return param_info

 class GeminiCheckpointIO(GeneralCheckpointIO):
     def __init__(self) -> None:
         super().__init__()

@@ -359,6 +363,8 @@ class GeminiPlugin(DPPluginBase):
     ) -> None:
         super().__init__()
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
+        if IS_NPU_AVAILABLE:
+            assert placement_policy == "static", "NPU only supports static placement policy"
         self.gemini_config = dict(
             chunk_config_dict=chunk_config_dict,
             chunk_init_device=(chunk_init_device or get_current_device()),

@@ -437,7 +443,7 @@ class GeminiPlugin(DPPluginBase):
         return True

     def supported_devices(self) -> List[str]:
-        return ["cuda"]
+        return ["cuda", "npu"]

     def configure(
         self,

@@ -485,4 +491,4 @@ class GeminiPlugin(DPPluginBase):
         return GeminiCheckpointIO()

     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError
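
With supported_devices() now reporting "npu" and the static-placement assertion above, the plugin can be driven the same way as on CUDA. A minimal usage sketch, assuming a torchrun-style launch and placeholder model/optimizer (illustrative only, not part of the PR):

import torch.nn as nn
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})
model = nn.Linear(1024, 1024)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
# on an NPU host, placement_policy must be "static" (asserted above)
plugin = GeminiPlugin(placement_policy="static", precision="bf16")
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)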

colossalai/booster/plugin/low_level_zero_plugin.py (2)

@@ -306,7 +306,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         return True

     def supported_devices(self) -> List[str]:
-        return ["cuda"]
+        return ["cuda", "npu"]

     def configure(
         self,

colossalai/initialize.py (7)

@@ -11,7 +11,7 @@ import torch.distributed as dist
 from colossalai.context import Config
 from colossalai.logging import get_dist_logger
-from colossalai.utils import set_device, set_seed
+from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed

 def launch(

@@ -47,12 +47,15 @@ def launch(
     if rank == 0:
         warnings.warn("`config` is deprecated and will be removed soon.")

+    if IS_NPU_AVAILABLE and backend == "nccl":
+        backend = "hccl"
+
     # init default process group
     init_method = f"tcp://[{host}]:{port}"
     dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)

     # set cuda device
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() or IS_NPU_AVAILABLE:
         # if local rank is not given, calculate automatically
         set_device(local_rank)
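
The backend swap is transparent to callers: a launch that requests "nccl" on a host where torch_npu reports an NPU ends up on "hccl". A hedged sketch of a manual single-process launch (launch_from_torch fills these arguments from torchrun's environment instead; values are placeholders):

import colossalai

colossalai.launch(
    config={},
    rank=0,
    world_size=1,
    host="127.0.0.1",
    port=29500,
    backend="nccl",  # rewritten to "hccl" when an NPU is detected
)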

colossalai/kernel/cuda_native/csrc/cpu_adam.h (2)

@@ -142,6 +142,7 @@ class Adam_Optimizer {
     }
   }

+#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
  inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
                        AVX_Data &data) {
    if (is_half) {

@@ -159,6 +160,7 @@ class Adam_Optimizer {
       SIMD_STORE(ptr, data.data);
     }
   }
+#endif

  void step(size_t step, float lr, float beta1, float beta2, float epsilon,
            float weight_decay, bool bias_correction, torch::Tensor &params,

colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp (304, new file)

@@ -0,0 +1,304 @@
#include "cpu_adam_arm.h"
void AdamOptimizer::Step_1(void *_params, void *grads, void *_exp_avg,
void *_exp_avg_sq, size_t _param_size,
at::ScalarType param_dtype,
at::ScalarType grad_dtype,
at::ScalarType exp_avg_dtype,
at::ScalarType exp_avg_sq_dtype, float loss_scale) {
size_t rounded_size = 0;
#if defined(__aarch64__)
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
#endif
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
#if defined(__aarch64__)
float32x4_t betta1_4 = simd_set(_betta1);
float32x4_t betta2_4 = simd_set(_betta2);
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
float32x4_t bias2_sqrt = simd_set(_bias_correction2);
float32x4_t eps_4 = simd_set(_eps);
float32x4_t step_size_4 = simd_set(step_size);
float32x4_t weight_decay_4;
if (_weight_decay > 0) {
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
}
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH) {
float32x4_t grad_4 = simd_load_offset(grads, grad_dtype, i);
if (loss_scale > 0) {
float32x4_t loss_scale_vec = simd_set(loss_scale);
grad_4 = vdivq_f32(grad_4, loss_scale_vec);
}
float32x4_t momentum_4 = simd_load_offset(_exp_avg, exp_avg_dtype, i);
float32x4_t variance_4 =
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i);
float32x4_t param_4 = simd_load_offset(_params, param_dtype, i);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4 = vfmaq_f32(grad_4, param_4, weight_decay_4);
}
momentum_4 = vmulq_f32(momentum_4, betta1_4);
momentum_4 = vfmaq_f32(momentum_4, grad_4, betta1_minus1_4);
variance_4 = vmulq_f32(variance_4, betta2_4);
grad_4 = vmulq_f32(grad_4, grad_4);
variance_4 = vfmaq_f32(variance_4, grad_4, betta2_minus1_4);
grad_4 = vsqrtq_f32(variance_4);
grad_4 = vfmaq_f32(eps_4, grad_4, bias2_sqrt);
grad_4 = vdivq_f32(momentum_4, grad_4);
if (_weight_decay > 0 && _adamw_mode) {
param_4 = vfmaq_f32(param_4, param_4, weight_decay_4);
}
param_4 = vfmaq_f32(param_4, grad_4, step_size_4);
simd_store_offset(_params, param_dtype, param_4, i);
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4, i);
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4, i);
}
}
#endif
if (_param_size > rounded_size) {
for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
float grad = scalar_load_offset(grads, grad_dtype, k);
if (loss_scale > 0) {
grad /= loss_scale;
}
float param = scalar_load_offset(_params, param_dtype, k);
float momentum = scalar_load_offset(_exp_avg, exp_avg_dtype, k);
float variance = scalar_load_offset(_exp_avg_sq, exp_avg_sq_dtype, k);
if (_weight_decay > 0 && !_adamw_mode) {
grad = param * _weight_decay + grad;
}
momentum = momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;
variance = variance * _betta2;
grad = grad * grad;
variance = grad * betta2_minus1 + variance;
grad = sqrt(variance);
grad = grad * _bias_correction2 + _eps;
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) {
param += w_decay * param;
}
param = grad * step_size + param;
scalar_store_offset(_params, param_dtype, param, k);
scalar_store_offset(_exp_avg, exp_avg_dtype, momentum, k);
scalar_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance, k);
}
}
}
}
void AdamOptimizer::Step_4(void *_params, void *grads, void *_exp_avg,
void *_exp_avg_sq, size_t _param_size,
at::ScalarType param_dtype,
at::ScalarType grad_dtype,
at::ScalarType exp_avg_dtype,
at::ScalarType exp_avg_sq_dtype, float loss_scale) {
size_t rounded_size = 0;
#if defined(__aarch64__)
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
#endif
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
#if defined(__aarch64__)
float32x4_t betta1_4 = simd_set(_betta1);
float32x4_t betta2_4 = simd_set(_betta2);
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
float32x4_t bias2_sqrt = simd_set(_bias_correction2);
float32x4_t eps_4 = simd_set(_eps);
float32x4_t step_size_4 = simd_set(step_size);
float32x4_t weight_decay_4;
if (_weight_decay > 0) {
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
}
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
float32x4_t grad_4[4];
float32x4_t momentum_4[4];
float32x4_t variance_4[4];
float32x4_t param_4[4];
#pragma unroll 4
for (int j = 0; j < 4; j++) {
grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
if (loss_scale > 0) {
float32x4_t loss_scale_vec = simd_set(loss_scale);
grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
}
momentum_4[j] =
simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
variance_4[j] =
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
}
momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
grad_4[j] = vsqrtq_f32(variance_4[j]);
grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
if (_weight_decay > 0 && _adamw_mode) {
param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
}
param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
i + SIMD_WIDTH * j);
}
}
}
#endif
if (_param_size > rounded_size) {
Step_1(scalar_seek_offset(_params, param_dtype, rounded_size),
scalar_seek_offset(grads, grad_dtype, rounded_size),
scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
(_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
exp_avg_sq_dtype, loss_scale);
}
}
void AdamOptimizer::Step_8(void *_params, void *grads, void *_exp_avg,
void *_exp_avg_sq, size_t _param_size,
at::ScalarType param_dtype,
at::ScalarType grad_dtype,
at::ScalarType exp_avg_dtype,
at::ScalarType exp_avg_sq_dtype, float loss_scale) {
size_t rounded_size = 0;
#if defined(__aarch64__)
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
#endif
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
#if defined(__aarch64__)
float32x4_t betta1_4 = simd_set(_betta1);
float32x4_t betta2_4 = simd_set(_betta2);
float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
float32x4_t bias2_sqrt = simd_set(_bias_correction2);
float32x4_t eps_4 = simd_set(_eps);
float32x4_t step_size_4 = simd_set(step_size);
float32x4_t weight_decay_4;
if (_weight_decay > 0) {
weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
}
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
float32x4_t grad_4[8];
float32x4_t momentum_4[8];
float32x4_t variance_4[8];
float32x4_t param_4[8];
#pragma unroll 4
for (int j = 0; j < 8; j++) {
grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
if (loss_scale > 0) {
float32x4_t loss_scale_vec = simd_set(loss_scale);
grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
}
momentum_4[j] =
simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
variance_4[j] =
simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
}
momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
grad_4[j] = vsqrtq_f32(variance_4[j]);
grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
if (_weight_decay > 0 && _adamw_mode) {
param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
}
param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
i + SIMD_WIDTH * j);
simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
i + SIMD_WIDTH * j);
}
}
}
#endif
if (_param_size > rounded_size) {
Step_4(scalar_seek_offset(_params, param_dtype, rounded_size),
scalar_seek_offset(grads, grad_dtype, rounded_size),
scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
(_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
exp_avg_sq_dtype, loss_scale);
}
}
void AdamOptimizer::step(size_t step, float lr, float beta1, float beta2,
float epsilon, float weight_decay,
bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale) {
auto params_c = params.contiguous();
auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
this->IncrementStep(step, beta1, beta2);
this->update_state(lr, epsilon, weight_decay, bias_correction);
this->Step_8(params_c.data_ptr(), grads_c.data_ptr(), exp_avg_c.data_ptr(),
exp_avg_sq_c.data_ptr(), params_c.numel(),
params_c.scalar_type(), grads_c.scalar_type(),
exp_avg_c.scalar_type(), exp_avg_sq_c.scalar_type(), loss_scale);
}
namespace py = pybind11;
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::class_<AdamOptimizer>(m, "CPUAdamOptimizer")
.def(py::init<float, float, float, float, float, bool>())
.def("step", &AdamOptimizer::step);
}
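
The pybind11 module above exposes the same CPUAdamOptimizer interface as the x86 kernel, so the Python side only has to pick the right builder. A rough, illustrative sketch of driving it directly (tensor sizes and hyper-parameters are placeholders; the ArmCPUAdamBuilder import path follows the cpu_adam.py hunk later in this diff):

import torch
from colossalai.kernel.op_builder import ArmCPUAdamBuilder

ext = ArmCPUAdamBuilder().load()  # JIT-builds cpu_adam_arm.cpp on aarch64
opt = ext.CPUAdamOptimizer(1e-3, 0.9, 0.999, 1e-8, 0.0, True)

param = torch.rand(1024)
grad = torch.rand(1024)
exp_avg = torch.zeros(1024)
exp_avg_sq = torch.zeros(1024)
# step(index, lr, beta1, beta2, eps, weight_decay, bias_correction,
#      params, grads, exp_avg, exp_avg_sq, loss_scale); loss_scale <= 0 disables scaling
opt.step(1, 1e-3, 0.9, 0.999, 1e-8, 0.0, True, param, grad, exp_avg, exp_avg_sq, -1.0)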

colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h (201, new file)

@@ -0,0 +1,201 @@
#pragma once
#include <ATen/ATen.h>
#include <torch/extension.h>
#include <cmath>
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#define TILE (128 * 1024 * 1024)
#if defined(__aarch64__)
#include <arm_neon.h>
#define SIMD_WIDTH 4
inline float32x4_t simd_load_offset(const void *ptr, at::ScalarType dtype,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float: {
auto ptr_f = reinterpret_cast<const float32_t *>(ptr);
return vld1q_f32(ptr_f + offset);
}
case at::ScalarType::Half: {
auto ptr_h = reinterpret_cast<const float16_t *>(ptr);
return vcvt_f32_f16(vld1_f16(ptr_h + offset));
}
// case at::ScalarType::BFloat16: {
// auto ptr_b = reinterpret_cast<const bfloat16_t *>(ptr);
// return vcvt_f32_bf16(vld1_bf16(ptr_b + offset));
// }
default:
AT_ERROR("Unsupported dtype");
break;
}
}
inline float32x4_t simd_load(void const *ptr, at::ScalarType dtype) {
return simd_load_offset(ptr, dtype, 0);
}
inline void simd_store_offset(void *ptr, at::ScalarType dtype, float32x4_t data,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float: {
auto ptr_f = reinterpret_cast<float32_t *>(ptr);
vst1q_f32(ptr_f + offset, data);
break;
}
case at::ScalarType::Half: {
auto ptr_h = reinterpret_cast<float16_t *>(ptr);
vst1_f16(ptr_h + offset, vcvt_f16_f32(data));
break;
}
// case at::ScalarType::BFloat16: {
// auto ptr_b = reinterpret_cast<bfloat16_t *>(ptr);
// vst1_bf16(ptr_b + offset, vcvt_bf16_f32(data));
// break;
// }
default:
AT_ERROR("Unsupported dtype");
break;
}
}
inline void simd_store(void *ptr, at::ScalarType dtype, float32x4_t data) {
return simd_store_offset(ptr, dtype, data, 0);
}
inline float32x4_t simd_set(float value) {
auto val = static_cast<float32_t>(value);
return vdupq_n_f32(val);
}
#endif
inline float scalar_load_offset(const void *ptr, at::ScalarType dtype,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float:
return *(reinterpret_cast<const float *>(ptr) + offset);
case at::ScalarType::Half:
return static_cast<float>(
*(reinterpret_cast<const at::Half *>(ptr) + offset));
// case at::ScalarType::BFloat16:
// return static_cast<float>(
// *(reinterpret_cast<const at::BFloat16 *>(ptr) + offset));
default:
AT_ERROR("Unsupported dtype");
break;
}
}
inline void scalar_store_offset(void *ptr, at::ScalarType dtype, float data,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float:
*(reinterpret_cast<float *>(ptr) + offset) = data;
break;
case at::ScalarType::Half:
*(reinterpret_cast<at::Half *>(ptr) + offset) = data;
break;
// case at::ScalarType::BFloat16:
// *(reinterpret_cast<at::BFloat16 *>(ptr) + offset) = data;
break;
default:
AT_ERROR("Unsupported dtype");
break;
}
}
inline void *scalar_seek_offset(void *ptr, at::ScalarType dtype,
size_t offset) {
switch (dtype) {
case at::ScalarType::Float:
return reinterpret_cast<float *>(ptr) + offset;
case at::ScalarType::Half:
return reinterpret_cast<at::Half *>(ptr) + offset;
// case at::ScalarType::BFloat16:
// return reinterpret_cast<at::BFloat16 *>(ptr) + offset;
default:
AT_ERROR("Unsupported dtype");
break;
}
}
#define STEP(SPAN) \
void Step_##SPAN(void *_params, void *grads, void *_exp_avg, \
void *_exp_avg_sq, size_t _param_size, \
at::ScalarType param_dtype, at::ScalarType grad_dtype, \
at::ScalarType exp_avg_dtype, \
at::ScalarType exp_avg_sq_dtype, float loss_scale = -1);
class AdamOptimizer {
private:
float _alpha;
float _betta1;
float _betta2;
float _eps;
float _weight_decay;
float _betta1_t;
float _betta2_t;
size_t _step;
float _bias_correction1;
float _bias_correction2;
bool _adamw_mode;
public:
AdamOptimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
float eps = 1e-8, float weight_decay = 0,
bool adamw_mode = true)
: _alpha(alpha),
_betta1(betta1),
_betta2(betta2),
_eps(eps),
_weight_decay(weight_decay),
_betta1_t(1.0),
_betta2_t(1.0),
_step(0),
_adamw_mode(adamw_mode) {}
~AdamOptimizer() {}
STEP(1)
STEP(4)
STEP(8)
inline void IncrementStep(size_t step, float beta1, float beta2) {
if (beta1 != _betta1 || beta2 != _betta2) {
_step = step;
_betta1 = beta1;
_betta2 = beta2;
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
} else {
_step++;
if (_step != step) {
_betta1_t = std::pow(_betta1, step);
_betta2_t = std::pow(_betta2, step);
_step = step;
} else {
_betta1_t *= _betta1;
_betta2_t *= _betta2;
}
}
}
inline void update_state(float lr, float epsilon, float weight_decay,
bool bias_correction) {
_alpha = lr;
_eps = epsilon;
_weight_decay = weight_decay;
_bias_correction1 = 1.0f;
_bias_correction2 = 1.0f;
if (bias_correction == 1) {
_bias_correction1 = 1 - _betta1_t;
_bias_correction2 = 1 / sqrt(1 - _betta2_t);
}
}
void step(size_t step, float lr, float beta1, float beta2, float epsilon,
float weight_decay, bool bias_correction, torch::Tensor &params,
torch::Tensor &grads, torch::Tensor &exp_avg,
torch::Tensor &exp_avg_sq, float loss_scale);
};
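
The scalar fallback loop in Step_1 doubles as the reference for the NEON path. A small PyTorch transcription of that same update, useful for checking the kernel on a handful of elements (illustrative sketch; hyper-parameters are placeholders):

import torch

def adam_reference(param, grad, exp_avg, exp_avg_sq, step, lr=1e-3, beta1=0.9,
                   beta2=0.999, eps=1e-8, weight_decay=0.0, adamw_mode=True):
    # mirrors update_state(): bias_correction2 is the reciprocal square root
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 / (1 - beta2**step) ** 0.5
    if weight_decay > 0 and not adamw_mode:
        grad = grad + weight_decay * param              # classic L2 term folded into the gradient
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = exp_avg_sq.sqrt() * bias_correction2 + eps  # sqrt(v) * bc2 + eps, as in the kernel
    if weight_decay > 0 and adamw_mode:
        param.add_(param, alpha=-lr * weight_decay)     # decoupled weight decay
    param.add_(exp_avg / denom, alpha=-lr / bias_correction1)
    return param, exp_avg, exp_avg_sq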

colossalai/kernel/cuda_native/mha/utils.py (2)

@@ -5,7 +5,7 @@ import torch
 import torch.nn.functional as F
 from einops import rearrange

-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device

 class Unpad(torch.autograd.Function):

colossalai/legacy/engine/schedule/_pipeline_schedule.py (2)

@@ -12,7 +12,7 @@ from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.logging import get_dist_logger
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device

 from ._base_schedule import BaseSchedule

colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py (2)

@@ -9,7 +9,7 @@ import colossalai.legacy.communication.p2p_v2 as comm
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.engine import Engine
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device

 from ._pipeline_schedule import PipelineSchedule

colossalai/legacy/nn/layer/parallel_1d/layers.py (2)

@@ -22,7 +22,7 @@ from colossalai.legacy.utils.checkpointing import (
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device

 from ..base_layer import ParallelLayer
 from ..colossalai_layer._utils import ColossalaiModule

colossalai/legacy/nn/layer/parallel_2d/layers.py (2)

@@ -18,7 +18,7 @@ from colossalai.legacy.utils.checkpointing import (
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device

 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple

colossalai/legacy/nn/layer/parallel_2p5d/layers.py (2)

@@ -19,7 +19,7 @@ from colossalai.legacy.utils.checkpointing import (
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device

 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple

colossalai/legacy/nn/layer/parallel_3d/layers.py (2)

@@ -27,7 +27,7 @@ from colossalai.legacy.utils.checkpointing import (
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device

 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
 from ._operation import (

colossalai/legacy/nn/layer/vanilla/layers.py (2)

@@ -10,7 +10,7 @@ from torch.nn.parameter import Parameter
 from colossalai.legacy.context import seed
 from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device

 from ..utils import to_2tuple

colossalai/legacy/zero/gemini/stateful_tensor_mgr.py (2)

@@ -3,7 +3,7 @@ import types
 from time import time
 from typing import List

-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device

 from .stateful_tensor import StatefulTensor, TensorState
 from .tensor_placement_policy import TensorPlacementPolicy

colossalai/nn/optimizer/cpu_adam.py (5)

@@ -1,9 +1,10 @@
 import math
+import platform
 from typing import Optional

 import torch

-from colossalai.kernel.op_builder import CPUAdamBuilder
+from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder

 from .nvme_optimizer import NVMeOptimizer

@@ -77,7 +78,7 @@ class CPUAdam(NVMeOptimizer):
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
         self.adamw_mode = adamw_mode
-        cpu_adam = CPUAdamBuilder().load()
+        cpu_adam = ArmCPUAdamBuilder().load() if platform.machine() == "aarch64" else CPUAdamBuilder().load()
         # if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification
         self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
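
With the platform check above, the same optimizer API now runs on an aarch64 CPU without a CUDA toolkit. A minimal usage sketch (model size and hyper-parameters are placeholders):

import torch
import torch.nn as nn
from colossalai.nn.optimizer import CPUAdam

model = nn.Linear(256, 256)                          # parameters stay on CPU
optimizer = CPUAdam(model.parameters(), lr=1e-3, adamw_mode=True)
model(torch.rand(8, 256)).sum().backward()
optimizer.step()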

colossalai/nn/optimizer/hybrid_adam.py (15)

@@ -84,9 +84,10 @@ class HybridAdam(CPUAdam):
             nvme_offload_fraction,
             nvme_offload_dir,
         )
-        fused_optim = FusedOptimBuilder().load()
-        self.gpu_adam_op = fused_optim.multi_tensor_adam
-        self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+        if torch.cuda.is_available():
+            fused_optim = FusedOptimBuilder().load()
+            self.gpu_adam_op = fused_optim.multi_tensor_adam
+            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

     @torch.no_grad()
     def step(self, closure=None, div_scale: float = -1):

@@ -118,11 +119,11 @@ class HybridAdam(CPUAdam):
                 group_step = state["step"]
                 beta1, beta2 = group["betas"]

-                if target_device.type == "cpu":
-                    assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
-                    assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
+                if target_device.type == "cpu" or target_device.type == "npu":
+                    assert state["exp_avg"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu"
+                    assert state["exp_avg_sq"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu"
                     self._pre_update(p, "exp_avg", "exp_avg_sq")
-                    if p.grad.dtype is torch.bfloat16:
+                    if p.grad.dtype is torch.bfloat16 or p.grad.device.type == "npu":
                         # cpu adam kernel does not support bf16 now
                         bias_correction1 = 1 - beta1 ** state["step"]
                         bias_correction2 = 1 - beta2 ** state["step"]

colossalai/pipeline/schedule/generate.py (2)

@@ -10,7 +10,7 @@ from torch.utils._pytree import tree_map
 from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device

 from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
 from .base import PipelineSchedule

colossalai/pipeline/schedule/interleaved_pp.py (2)

@@ -9,7 +9,7 @@ from torch.utils._pytree import tree_map
 from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device

 from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
 from .base import PipelineSchedule

colossalai/pipeline/schedule/one_f_one_b.py (2)

@@ -9,7 +9,7 @@ from torch.utils._pytree import tree_map
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device

 from ._utils import (
     detach,

colossalai/shardformer/layer/normalization.py (58)

@@ -2,16 +2,19 @@
 # -*- encoding: utf-8 -*-
 import warnings
 from abc import ABC, abstractmethod

 import torch.nn as nn

 from colossalai.lazy import LazyInitContext

 from ._operation import hook_paramter_in_backward
 from .utils import SeqParallelUtils

 __all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]

 try:
     from apex.contrib.layer_norm.layer_norm import FastLayerNorm

     EnableFastLayerNorm = True
 except ImportError:
     EnableFastLayerNorm = False

@@ -19,10 +22,27 @@ except ImportError:
 try:
     from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
     from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm

+    class FusedLayerNormWithHook(ApexFusedLayerNorm):
+        def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
+            super().__init__(normalized_shape, eps, elementwise_affine)
+
+        def forward(self, input):
+            output = super().forward(input)
+            output = hook_paramter_in_backward(output, self.weight, self.bias)
+            return output
+
+    class FusedRMSNormWithHook(ApexFusedRMSNorm):
+        def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
+            super().__init__(normalized_shape, eps, elementwise_affine)
+
+        def forward(self, input):
+            output = super().forward(input)
+            output = hook_paramter_in_backward(output, self.weight)
+            return output
+
 except ImportError:
-    warnings.warn(
-        "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel"
-    )
+    warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel")

 FAST_LAYERNORM_SUPPORTED_SIZE = [
     1024,

@@ -52,6 +72,7 @@ FAST_LAYERNORM_SUPPORTED_SIZE = [
 ]

 if EnableFastLayerNorm:
     class FastLayerNormWithHook(FastLayerNorm):
         def __init__(self, hidden_size, eps=0.00001):
             super().__init__(hidden_size, eps)

@@ -60,25 +81,7 @@ if EnableFastLayerNorm:
             output = super().forward(input)
             output = hook_paramter_in_backward(output, self.weight, self.bias)
             return output
-
-class FusedLayerNormWithHook(ApexFusedLayerNorm):
-    def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
-        super().__init__(normalized_shape, eps, elementwise_affine)
-
-    def forward(self, input):
-        output = super().forward(input)
-        output = hook_paramter_in_backward(output, self.weight, self.bias)
-        return output
-
-class FusedRMSNormWithHook(ApexFusedRMSNorm):
-    def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
-        super().__init__(normalized_shape, eps, elementwise_affine)
-
-    def forward(self, input):
-        output = super().forward(input)
-        output = hook_paramter_in_backward(output, self.weight)
-        return output

 class BaseLayerNorm(ABC):
     @abstractmethod

@@ -244,12 +247,13 @@ class FusedRMSNorm(BaseLayerNorm):
     """
     This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
     """

     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedRMSNorm is not implemented as a physical class. "
             "It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex."
         )

     @staticmethod
     def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
         r"""

@@ -264,7 +268,7 @@ class FusedRMSNorm(BaseLayerNorm):
             nn.Module: FusedRMSNorm module.
         """
         try:
-            from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
+            pass
         except ImportError:
             raise ImportError(
                 "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"

@@ -282,7 +286,9 @@ class FusedRMSNorm(BaseLayerNorm):
         eps = module.eps
         elementwise_affine = module.elementwise_affine

-        rmsnorm = FusedRMSNormWithHook(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
+        rmsnorm = FusedRMSNormWithHook(
+            normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
+        )

         rmsnorm.weight = module.weight
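
The hook-wrapped classes above are normally reached through the from_native_module interface rather than instantiated directly. A quick illustrative sketch (requires apex; not part of the PR):

import torch.nn as nn
from colossalai.shardformer.layer import FusedLayerNorm

ln = nn.LayerNorm(1024)
# returns a hook-wrapped fused layer norm; the fast apex-contrib kernel is
# used when the hidden size is in FAST_LAYERNORM_SUPPORTED_SIZE
fused_ln = FusedLayerNorm.from_native_module(ln)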

colossalai/utils/__init__.py (3)

@@ -7,7 +7,7 @@ from .common import (
     is_ddp_ignored,
     set_seed,
 )
-from .cuda import empty_cache, get_current_device, set_device, set_to_cuda, synchronize
+from .device import IS_NPU_AVAILABLE, empty_cache, get_current_device, set_device, set_to_cuda, synchronize
 from .multi_tensor_apply import multi_tensor_applier
 from .tensor_detector import TensorDetector
 from .timer import MultiTimer, Timer

@@ -29,4 +29,5 @@ __all__ = [
     "set_seed",
     "is_ddp_ignored",
     "set_device",
+    "IS_NPU_AVAILABLE",
 ]

colossalai/utils/cuda.py (56, deleted)

@@ -1,56 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional
import torch
import torch.distributed as dist
def set_to_cuda(models):
"""Send model to gpu.
:param models: nn.module or a list of module
"""
if isinstance(models, list) and len(models) > 1:
ret = []
for model in models:
ret.append(model.to(get_current_device()))
return ret
elif isinstance(models, list):
return models[0].to(get_current_device())
else:
return models.to(get_current_device())
def get_current_device() -> torch.device:
"""
Returns currently selected device (gpu/cpu).
If cuda available, return gpu, otherwise return cpu.
"""
if torch.cuda.is_available():
return torch.device(f"cuda:{torch.cuda.current_device()}")
else:
return torch.device("cpu")
def synchronize():
"""Similar to cuda.synchronize().
Waits for all kernels in all streams on a CUDA device to complete.
"""
if torch.cuda.is_available():
torch.cuda.synchronize()
def empty_cache():
"""Similar to cuda.empty_cache()
Releases all unoccupied cached memory currently held by the caching allocator.
"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
def set_device(index: Optional[int] = None) -> None:
if index is None:
index = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(index)

colossalai/utils/device.py (207, new file)

@@ -0,0 +1,207 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
IS_NPU_AVAILABLE: bool = False
try:
import torch_npu # noqa
IS_NPU_AVAILABLE = torch.npu.is_available()
except ImportError:
pass
def set_to_cuda(models):
"""Send model to gpu.
:param models: nn.module or a list of module
"""
if isinstance(models, list) and len(models) > 1:
ret = []
for model in models:
ret.append(model.to(get_current_device()))
return ret
elif isinstance(models, list):
return models[0].to(get_current_device())
else:
return models.to(get_current_device())
def get_current_device() -> torch.device:
"""
Returns the currently selected device (CUDA GPU, NPU, or CPU).
Return the current CUDA device if available, else the current NPU device if available, otherwise CPU.
"""
if torch.cuda.is_available():
return torch.device(f"cuda:{torch.cuda.current_device()}")
elif IS_NPU_AVAILABLE:
return torch.device(f"npu:{torch.npu.current_device()}")
else:
return torch.device("cpu")
def _dispatch_device_func(fn_name: str, *args, **kwargs):
if torch.cuda.is_available():
return getattr(torch.cuda, fn_name)(*args, **kwargs)
elif IS_NPU_AVAILABLE:
return getattr(torch.npu, fn_name)(*args, **kwargs)
else:
raise RuntimeError("No device available")
# device semantics
def can_device_access_peer(device, peer_device) -> bool:
return _dispatch_device_func("can_device_access_peer", device, peer_device)
def current_device() -> int:
return _dispatch_device_func("current_device")
def current_stream(device=None):
return _dispatch_device_func("current_stream", device)
def default_stream(device=None):
return _dispatch_device_func("default_stream", device)
def device_count() -> int:
return _dispatch_device_func("device_count")
def get_device_capability(device=None) -> Tuple[int, int]:
return _dispatch_device_func("get_device_capability", device)
def get_device_name(device=None) -> str:
return _dispatch_device_func("get_device_name", device)
def get_device_properties(device):
return _dispatch_device_func("get_device_properties", device)
def set_device(index: Optional[int] = None) -> None:
if index is None:
index = dist.get_rank() % device_count()
_dispatch_device_func("set_device", index)
def set_stream(stream_):
return _dispatch_device_func("set_stream", stream_)
def stream(stream_):
return _dispatch_device_func("stream", stream_)
def synchronize():
return _dispatch_device_func("synchronize")
def utilization(device=None) -> int:
return _dispatch_device_func("utilization", device)
# random number generator
def get_rng_state(device="cuda") -> torch.Tensor:
return _dispatch_device_func("get_rng_state", device)
def get_rng_state_all() -> List[torch.Tensor]:
return _dispatch_device_func("get_rng_state_all")
def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None:
return _dispatch_device_func("set_rng_state", new_state, device)
def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None:
return _dispatch_device_func("set_rng_state_all", new_states)
def manual_seed(seed: int) -> None:
return _dispatch_device_func("manual_seed", seed)
def manual_seed_all(seed: int) -> None:
return _dispatch_device_func("manual_seed_all", seed)
def seed() -> None:
return _dispatch_device_func("seed")
def seed_all() -> None:
return _dispatch_device_func("seed_all")
def initial_seed() -> int:
return _dispatch_device_func("initial_seed")
# streams and events
def Stream(device=None, priority=0, **kwargs):
return _dispatch_device_func("Stream", device, priority, **kwargs)
def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
return _dispatch_device_func("Event", enable_timing, blocking, interprocess)
# memory management
def empty_cache() -> None:
return _dispatch_device_func("empty_cache")
def memory_stats(device=None) -> Dict[str, Any]:
return _dispatch_device_func("memory_stats", device)
def memory_summary(device=None, abbreviated=False) -> str:
return _dispatch_device_func("memory_summary", device, abbreviated)
def memory_snapshot():
return _dispatch_device_func("memory_snapshot")
def memory_allocated(device=None) -> int:
return _dispatch_device_func("memory_allocated", device)
def max_memory_allocated(device=None) -> int:
return _dispatch_device_func("max_memory_allocated", device)
def reset_max_memory_allocated(device=None) -> None:
return _dispatch_device_func("reset_max_memory_allocated", device)
def memory_reserved(device=None) -> int:
return _dispatch_device_func("memory_reserved", device)
def max_memory_reserved(device=None) -> int:
return _dispatch_device_func("max_memory_reserved", device)
def set_per_process_memory_fraction(fraction: float, device=None) -> None:
return _dispatch_device_func("set_per_process_memory_fraction", fraction, device)
def reset_peak_memory_stats(device=None) -> None:
return _dispatch_device_func("reset_peak_memory_stats", device)
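
Downstream code consumes these helpers through a single import and gets CUDA, Ascend NPU (when torch_npu is installed), or CPU behaviour automatically. An illustrative sketch:

import torch
from colossalai.utils.device import empty_cache, get_current_device, synchronize

device = get_current_device()        # cuda:N, npu:N, or cpu
x = torch.ones(4, 4, device=device)
if device.type != "cpu":
    synchronize()                    # dispatched to torch.cuda or torch.npu
    empty_cache()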

colossalai/utils/timer.py (2)

@@ -3,7 +3,7 @@
 import time
 from typing import Tuple

-from .cuda import synchronize
+from .device import synchronize

 class Timer:

colossalai/zero/gemini/chunk/chunk.py (15)

@@ -7,6 +7,7 @@ import torch.distributed as dist
 from torch.distributed import ProcessGroup

 from colossalai.utils import get_current_device
+from colossalai.utils.device import IS_NPU_AVAILABLE

 class TensorState(Enum):

@@ -172,7 +173,7 @@ class Chunk:
         if self.chunk_temp is not None:
             # this chunk is not closed
-            if self.chunk_temp.device.type == "cuda":
+            if self.chunk_temp.device.type == "cuda" or self.chunk_temp.device.type == "npu":
                 cuda_memory += self.chunk_mem
             else:
                 cpu_memory += self.chunk_mem

@@ -191,10 +192,8 @@ class Chunk:
         if self.chunk_temp is not None:
             return self.chunk_temp.device.type
         else:
-            if self.is_gathered:
-                return "cuda"
-            elif self.cuda_shard is not None:
-                return "cuda"
+            if self.is_gathered or self.cuda_shard is not None:
+                return "npu" if IS_NPU_AVAILABLE else "cuda"
             else:
                 return "cpu"

@@ -329,12 +328,12 @@ class Chunk:
         # when the current chunk is not synchronized with the optimizer
         # just use another way for the movement
         if not self.optim_sync_flag:
-            assert device.type == "cuda", "each chunk should first be moved to CUDA"
+            assert device.type == "cuda" or device.type == "npu", "each chunk should first be moved to CUDA"
             self.__paired_shard_move()
             self.optim_sync_flag = True
             return

-        if device.type == "cuda":
+        if device.type == "cuda" or device.type == "npu":
             assert device == get_current_device(), "can't move chunk to another device"
             if self.cuda_shard:

@@ -484,7 +483,7 @@ class Chunk:
             assert friend_chunk.is_gathered is True
             self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk)
             self.optim_sync_flag = True
-        elif friend_chunk.device_type == "cuda" and self.device_type == "cuda":
+        elif friend_chunk.device_type in ("cuda", "npu") and self.device_type in ("cuda", "npu"):
             self.cuda_shard.copy_(friend_chunk.cuda_shard)
             self.optim_sync_flag = True
             self.cpu_vis_flag = False

colossalai/zero/gemini/chunk/manager.py (5)

@@ -206,7 +206,10 @@ class ChunkManager:
             tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
         """
         assert tensor not in self.tensor_chunk_map
-        self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
+        device_type = tensor.device.type
+        if device_type == "npu":
+            device_type = "cuda"
+        self.total_mem[device_type] += tensor.numel() * tensor.element_size()

     def __repr__(self) -> str:
         msg = [

colossalai/zero/gemini/gemini_ddp.py (66)

@@ -10,32 +10,30 @@ import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import _get_default_group

-from colossalai.checkpoint_io.utils import StateDictSharder
+from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
 from colossalai.tensor.colo_parameter import ColoParameter
-from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
-from colossalai.checkpoint_io.utils import gather_distributed_param
-from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
-from .gemini_hook import GeminiZeROHook
-from .gemini_mgr import GeminiManager
-from .memory_tracer import MemStats, OrderedParamGenerator
-from .utils import get_temp_total_chunk_on_cuda
 from colossalai.tensor.d_tensor import (
     distribute_tensor,
     distribute_tensor_with_customization,
-    init_tensor_as_customization_distributed,
     get_device_mesh,
+    get_global_shape,
     get_sharding_spec,
+    init_as_dtensor,
+    init_tensor_as_customization_distributed,
     is_customized_distributed_tensor,
     is_distributed_tensor,
-    get_global_shape,
-    init_as_dtensor
 )
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
+from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
+
+from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
+from .gemini_hook import GeminiZeROHook
+from .gemini_mgr import GeminiManager
+from .memory_tracer import MemStats, OrderedParamGenerator
+from .utils import get_temp_total_chunk_on_cuda

 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys

@@ -162,7 +160,7 @@ class GeminiDDP(ModelWrapper):
         self._init_chunks(
             param_order=param_order,
             strict_ddp_mode=strict_ddp_mode,
-            cpu_offload=self.gemini_manager.policy_name != "cuda",
+            cpu_offload=not (self.gemini_manager.policy_name == "static" and offload_param_frac == 0),
             pin_memory=pin_memory,
         )
         super().__init__(module)

@@ -453,12 +451,13 @@ class GeminiDDP(ModelWrapper):
                     global_shape = get_global_shape(tensor)
                     device_mesh = get_device_mesh(tensor)
                     shard_spec = get_sharding_spec(tensor)
-                    record_tensor = init_as_dtensor(record_tensor,
-                                                    device_mesh=device_mesh,
-                                                    sharding_spec=shard_spec,
-                                                    global_shape = global_shape)
+                    record_tensor = init_as_dtensor(
+                        record_tensor, device_mesh=device_mesh, sharding_spec=shard_spec, global_shape=global_shape
+                    )
                 elif is_customized_distributed_tensor(tensor):
-                    init_tensor_as_customization_distributed(record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn)
+                    init_tensor_as_customization_distributed(
+                        record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn
+                    )
                 record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu()

                 assert tensor not in chunk_to_save_data

@@ -634,7 +633,15 @@ class GeminiDDP(ModelWrapper):
         local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
         local_state = {k: v for k, v in local_name_params if v is not None}

-        def load(param_name, dest_tensor, copy_func, source_device_mesh=None, source_sharding_spec=None, shard_fn=None, gather_fn=None):
+        def load(
+            param_name,
+            dest_tensor,
+            copy_func,
+            source_device_mesh=None,
+            source_sharding_spec=None,
+            shard_fn=None,
+            gather_fn=None,
+        ):
             state_key = prefix + param_name
             if state_key in state_dict:
                 input_param = state_dict[state_key]

@@ -642,7 +649,9 @@ class GeminiDDP(ModelWrapper):
                 if source_device_mesh is not None and source_sharding_spec is not None:
                     input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
                 elif shard_fn is not None and gather_fn is not None:
-                    input_param = distribute_tensor_with_customization(input_param, shard_fn=shard_fn, gather_fn=gather_fn)
+                    input_param = distribute_tensor_with_customization(
+                        input_param, shard_fn=shard_fn, gather_fn=gather_fn
+                    )

                 # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                 if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:

@@ -687,7 +696,6 @@ class GeminiDDP(ModelWrapper):
             temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)

             for tensor, tensor_info in chunk.tensors_info.items():
                 source_device_mesh, source_sharding_spec, shard_fn, gather_fn = None, None, None, None
                 if is_distributed_tensor(tensor):
                     # shard the input param
@ -699,7 +707,15 @@ class GeminiDDP(ModelWrapper):
parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor] parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end] parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
load(parameter_name, tensor, partial(load_parameter, parameter_slice), source_device_mesh, source_sharding_spec, shard_fn, gather_fn) load(
parameter_name,
tensor,
partial(load_parameter, parameter_slice),
source_device_mesh,
source_sharding_spec,
shard_fn,
gather_fn,
)
if chunk.is_gathered: if chunk.is_gathered:
chunk.cuda_global_chunk.copy_(temp_chunk) chunk.cuda_global_chunk.copy_(temp_chunk)
@ -799,7 +815,7 @@ class GeminiDDP(ModelWrapper):
for buffer in self.module.buffers(): for buffer in self.module.buffers():
if isinstance(buffer, LazyTensor): if isinstance(buffer, LazyTensor):
buffer.materialize() buffer.materialize()
buffer.data = buffer.cuda() buffer.data = buffer.to(get_current_device())
if torch.is_floating_point(buffer): if torch.is_floating_point(buffer):
buffer.data = buffer.to(self.mixed_precision) buffer.data = buffer.to(self.mixed_precision)

6
colossalai/zero/gemini/gemini_mgr.py

@ -17,9 +17,7 @@ class GeminiManager:
https://arxiv.org/abs/2108.05818 https://arxiv.org/abs/2108.05818
Args: Args:
placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' and 'auto'. placement_policy (str): Which device to place *held* tensors. It can be 'static' and 'auto'.
If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means min CUDA memory will be used.
If it's 'cuda', they won't be offloaded, which means max CUDA memory will be used.
If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well. If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well.
Note that 'auto' policy can only work well when no other processes use CUDA during your training. Note that 'auto' policy can only work well when no other processes use CUDA during your training.
chunk_manager (ChunkManager): A ``ChunkManager`` instance. chunk_manager (ChunkManager): A ``ChunkManager`` instance.
@ -121,7 +119,7 @@ class GeminiManager:
start = time() start = time()
cuda_demand = 0 cuda_demand = 0
for chunk in chunks: for chunk in chunks:
if chunk.device_type == "cuda": if chunk.device_type == "cuda" or chunk.device_type == "npu":
if chunk.is_gathered: if chunk.is_gathered:
pass pass
else: else:
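
Note: the updated docstring narrows placement_policy to 'static' and 'auto'; the removed 'cpu' and 'cuda' behaviours are now expressed through the static policy's offload fractions (offload_param_frac=0 keeps parameters on the accelerator, 1 offloads them all, see the GeminiDDP change above). A minimal usage sketch, assuming the policy is configured through GeminiPlugin as in the llama2 example touched by this PR; arguments other than placement_policy are illustrative and may differ between releases:

    import torch.nn as nn

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin
    from colossalai.nn.optimizer import HybridAdam

    colossalai.launch_from_torch(config={})

    # 'auto' migrates held chunks between CPU and CUDA/NPU based on memory pressure;
    # 'static' pins the configured fractions (offload_param_frac=0.0 mirrors the
    # removed 'cuda' policy, offload_param_frac=1.0 the removed 'cpu' policy).
    plugin = GeminiPlugin(placement_policy="auto", precision="bf16")
    booster = Booster(plugin=plugin)

    model = nn.Linear(1024, 1024)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    model, optimizer, *_ = booster.boost(model, optimizer)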

53
colossalai/zero/gemini/gemini_optimizer.py

@ -7,31 +7,29 @@ from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from packaging.version import Version from packaging.version import Version
from torch.distributed import ProcessGroup
from torch.nn import Parameter from torch.nn import Parameter
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.distributed import ProcessGroup
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.checkpoint_io.utils import StateDictSharder from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
from colossalai.interface import OptimizerWrapper from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager
from .gemini_ddp import GeminiDDP
from colossalai.checkpoint_io.utils import gather_distributed_param
from colossalai.tensor.d_tensor import ( from colossalai.tensor.d_tensor import (
distribute_tensor, distribute_tensor,
distribute_tensor_with_customization, distribute_tensor_with_customization,
init_tensor_as_customization_distributed,
get_device_mesh, get_device_mesh,
get_sharding_spec, get_sharding_spec,
init_as_dtensor,
init_tensor_as_customization_distributed,
is_customized_distributed_tensor, is_customized_distributed_tensor,
is_distributed_tensor, is_distributed_tensor,
get_global_shape,
init_as_dtensor
) )
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager
from .gemini_ddp import GeminiDDP
__all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"] __all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"]
@ -312,7 +310,7 @@ class GeminiOptimizer(OptimizerWrapper):
chunk16 = self.param_to_chunk16[fake_param] chunk16 = self.param_to_chunk16[fake_param]
chunk32 = chunk16.paired_chunk chunk32 = chunk16.paired_chunk
if chunk32.device_type == "cuda": if chunk32.device_type == "cuda" or chunk32.device_type == "npu":
continue continue
if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem: if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
@ -326,7 +324,7 @@ class GeminiOptimizer(OptimizerWrapper):
for fake_param in group["params"]: for fake_param in group["params"]:
chunk16 = self.param_to_chunk16[fake_param] chunk16 = self.param_to_chunk16[fake_param]
chunk32 = chunk16.paired_chunk chunk32 = chunk16.paired_chunk
if chunk32.device_type == "cuda": if chunk32.device_type == "cuda" or chunk32.device_type == "npu":
state = self.optim.state[fake_param] state = self.optim.state[fake_param]
for k, v in state.items(): for k, v in state.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
@ -479,15 +477,19 @@ class GeminiOptimizer(OptimizerWrapper):
state_tensor = states[state_name].detach().clone().to(torch.float32).cpu() state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
if is_dtensor: if is_dtensor:
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device) state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
state_tensor = init_as_dtensor(state_tensor, state_tensor = init_as_dtensor(
device_mesh=device_mesh, state_tensor,
sharding_spec=shard_spec, device_mesh=device_mesh,
global_shape = global_shape) sharding_spec=shard_spec,
global_shape=global_shape,
)
elif is_customized_distributed: elif is_customized_distributed:
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device) state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn) init_tensor_as_customization_distributed(
state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
)
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
collected_states[state_name] = state_tensor.reshape(global_shape) collected_states[state_name] = state_tensor.reshape(global_shape)
return collected_states return collected_states
@ -533,13 +535,14 @@ class GeminiOptimizer(OptimizerWrapper):
collected_states[state_name] = torch.reshape(state_tensor, param.shape) collected_states[state_name] = torch.reshape(state_tensor, param.shape)
if is_dtensor: if is_dtensor:
state_tensor = state_tensor.to(param.device) state_tensor = state_tensor.to(param.device)
state_tensor = init_as_dtensor(state_tensor, state_tensor = init_as_dtensor(
sharding_spec=shard_spec, state_tensor, sharding_spec=shard_spec, device_mesh=device_mesh, global_shape=global_shape
device_mesh=device_mesh, )
global_shape=global_shape)
elif is_customized_distributed: elif is_customized_distributed:
state_tensor = state_tensor.to(param.device) state_tensor = state_tensor.to(param.device)
init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn) init_tensor_as_customization_distributed(
state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
)
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
return collected_states return collected_states
@ -548,7 +551,7 @@ class GeminiOptimizer(OptimizerWrapper):
self, self,
param_id: int, param_id: int,
state_names: list, state_names: list,
device: torch.device = torch.device("cuda"), device: torch.device = get_current_device(),
dtype: torch.dtype = torch.float32, dtype: torch.dtype = torch.float32,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
@ -705,7 +708,7 @@ class GeminiOptimizer(OptimizerWrapper):
ret_val = torch.zeros( ret_val = torch.zeros(
state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False
) )
if is_dtensor: if is_dtensor:
value = torch.reshape(value, global_shape) value = torch.reshape(value, global_shape)
value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh) value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh)

52
colossalai/zero/low_level/low_level_optim.py

@ -12,6 +12,7 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.optim import Optimizer from torch.optim import Optimizer
import colossalai.utils.device as device_utils
from colossalai.amp.naive_amp.mixed_precision_mixin import ( from colossalai.amp.naive_amp.mixed_precision_mixin import (
BF16MixedPrecisionMixin, BF16MixedPrecisionMixin,
FP16MixedPrecisionMixin, FP16MixedPrecisionMixin,
@ -22,7 +23,7 @@ from colossalai.logging import get_dist_logger
from colossalai.tensor.moe_tensor.api import is_moe_tensor from colossalai.tensor.moe_tensor.api import is_moe_tensor
# from colossalai.tensor import ColoParameter, ProcessGroup # from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device
from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, ParameterStore from .bookkeeping import BucketStore, GradientStore, ParameterStore
@ -182,7 +183,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# intialize communication stream for # intialize communication stream for
# communication-compuation overlapping # communication-compuation overlapping
if self._overlap_communication: if self._overlap_communication:
self._comm_stream = torch.cuda.Stream() self._comm_stream = device_utils.Stream()
# reduction hook is only used if overlapping communication # reduction hook is only used if overlapping communication
# or stage 2 is used # or stage 2 is used
@ -216,7 +217,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
return len(self._working_param_groups) return len(self._working_param_groups)
def _sanity_checks(self): def _sanity_checks(self):
assert torch.cuda.is_available(), "CUDA is required" assert torch.cuda.is_available() or IS_NPU_AVAILABLE, "device is required"
for param_group in self.optim.param_groups: for param_group in self.optim.param_groups:
group_params = param_group["params"] group_params = param_group["params"]
for param in group_params: for param in group_params:
@ -339,11 +340,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if len(moe_grad_list) > 0: if len(moe_grad_list) > 0:
moe_flat_grads.record_stream(stream) moe_flat_grads.record_stream(stream)
# waiting for ops in the default stream finishing # waiting for ops in the default stream finishing
stream.wait_stream(torch.cuda.current_stream()) stream.wait_stream(device_utils.current_stream())
else: else:
stream = torch.cuda.current_stream() stream = device_utils.current_stream()
with torch.cuda.stream(stream): with device_utils.stream(stream):
group_id = self._bucket_store.current_group_id group_id = self._bucket_store.current_group_id
if self.moe_extra_dp_pg is None: if self.moe_extra_dp_pg is None:
@ -485,7 +486,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# clear reduced grads # clear reduced grads
if self._overlap_communication: if self._overlap_communication:
torch.cuda.synchronize() device_utils.synchronize()
self.zero_grad() self.zero_grad()
@ -504,7 +505,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# clear reduced grads # clear reduced grads
if self._overlap_communication: if self._overlap_communication:
torch.cuda.synchronize() device_utils.synchronize()
self.zero_grad() self.zero_grad()
@ -620,22 +621,25 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
release_param_grad(self._master_param_groups_of_current_rank[group_id]) release_param_grad(self._master_param_groups_of_current_rank[group_id])
# update working partition updated by the current rank # update working partition updated by the current rank
device = get_current_device()
for group_id in range(self.num_param_groups): for group_id in range(self.num_param_groups):
master_working_param = self.optim.param_groups[group_id]["params"] master_working_param = self.optim.param_groups[group_id]["params"]
for idx, splited_param in enumerate(master_working_param): for idx, splited_param in enumerate(master_working_param):
working_param = real_working_params[group_id][idx] working_param = real_working_params[group_id][idx]
if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param): if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
all_splited_param = [ all_splited_param = [
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
for _ in range(self.moe_extra_dp_pg_size) for _ in range(self.moe_extra_dp_pg_size)
] ]
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.moe_extra_dp_pg) dist.all_gather(
all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg
)
else: else:
all_splited_param = [ all_splited_param = [
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
for _ in range(self._world_size) for _ in range(self._world_size)
] ]
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg) dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)) working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
@ -657,7 +661,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
norm_type = float(norm_type) norm_type = float(norm_type)
if norm_type == inf: if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients) total_norm = max(grad.data.abs().max() for grad in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float)
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg) dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
total_norm = total_norm_cuda.item() total_norm = total_norm_cuda.item()
@ -668,7 +672,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
total_norm_exponentiated += grad_norm_exponentiated total_norm_exponentiated += grad_norm_exponentiated
# Sum across all model parallel GPUs. # Sum across all model parallel GPUs.
total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)]) total_norm_exponentiated_cuda = torch.tensor(
[float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float
)
torch.distributed.all_reduce( torch.distributed.all_reduce(
total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
) )
@ -759,6 +765,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
Dict: the pytorch form state_dict Dict: the pytorch form state_dict
""" """
zero_state = dict() zero_state = dict()
device = get_current_device()
for param, state in self.optim.state.items(): for param, state in self.optim.state.items():
zero_state[param] = copy.deepcopy(state) zero_state[param] = copy.deepcopy(state)
for k, v in state.items(): for k, v in state.items():
@ -766,14 +773,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
working_param = self._param_store.master_to_working_param[id(param)] working_param = self._param_store.master_to_working_param[id(param)]
if self.moe_extra_dp_pg is not None and is_moe_tensor(v): if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
gather_tensor = [ gather_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size) torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
] ]
dist.all_gather(gather_tensor, v.cuda(), group=self.moe_extra_dp_pg) dist.all_gather(gather_tensor, v.to(device), group=self.moe_extra_dp_pg)
else: else:
gather_tensor = [ gather_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size) torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
] ]
dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg) dist.all_gather(gather_tensor, v.to(device), group=self.dp_pg)
param_state = ( param_state = (
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
) )
@ -820,6 +827,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
ret_block = dict() ret_block = dict()
ret_block_size = 0 ret_block_size = 0
device = get_current_device()
local_states = self.optim.state_dict()["state"] local_states = self.optim.state_dict()["state"]
for param_idx, states in local_states.items(): for param_idx, states in local_states.items():
current_block_size = 0 current_block_size = 0
@ -836,14 +844,14 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if isinstance(v, torch.Tensor) and k != "step": if isinstance(v, torch.Tensor) and k != "step":
if self.moe_extra_dp_pg is not None and is_moe_tensor(v): if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
state_tensor = [ state_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size) torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
] ]
dist.all_gather(state_tensor, v.cuda(), group=self.moe_extra_dp_pg) dist.all_gather(state_tensor, v.to(device), group=self.moe_extra_dp_pg)
else: else:
state_tensor = [ state_tensor = [
torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size) torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
] ]
dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg) dist.all_gather(state_tensor, v.to(device), group=self.dp_pg)
state_tensor = ( state_tensor = (
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu() torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
) )

5
examples/language/llama2/benchmark.py

@ -13,6 +13,7 @@ from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM from transformers.models.llama.modeling_llama import LlamaForCausalLM
import colossalai import colossalai
import colossalai.utils.device as device_utils
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
@ -194,7 +195,7 @@ def main():
torch.set_default_dtype(torch.bfloat16) torch.set_default_dtype(torch.bfloat16)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader) model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
torch.set_default_dtype(torch.float) torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") coordinator.print_on_master(f"Booster init max CUDA memory: {device_utils.max_memory_allocated()/1024**2:.2f} MB")
coordinator.print_on_master( coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB" f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
) )
@ -220,7 +221,7 @@ def main():
performance_evaluator.on_step_end(**batch) performance_evaluator.on_step_end(**batch)
performance_evaluator.on_fit_end() performance_evaluator.on_fit_end()
coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB") coordinator.print_on_master(f"Max CUDA memory usage: {device_utils.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__": if __name__ == "__main__":

8
examples/language/llama2/performance_evaluator.py

@ -5,7 +5,9 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from torch import Tensor from torch import Tensor
import colossalai.utils.device as device_utils
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.utils.device import get_current_device
def divide(x: float, y: float) -> float: def divide(x: float, y: float) -> float:
@ -20,7 +22,7 @@ def divide(x: float, y: float) -> float:
def all_reduce_mean(x: float, world_size: int) -> float: def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1: if world_size == 1:
return x return x
tensor = torch.tensor([x], device=torch.cuda.current_device()) tensor = torch.tensor([x], device=get_current_device())
dist.all_reduce(tensor) dist.all_reduce(tensor)
tensor = tensor / world_size tensor = tensor / world_size
return tensor.item() return tensor.item()
@ -84,13 +86,13 @@ class PerformanceEvaluator:
self.disable = self.ignore_steps > 0 and step < self.ignore_steps self.disable = self.ignore_steps > 0 and step < self.ignore_steps
if self.disable: if self.disable:
return return
torch.cuda.synchronize() device_utils.synchronize()
self.timer.start() self.timer.start()
def on_step_end(self, input_ids: Tensor, **kwargs) -> None: def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
if self.disable: if self.disable:
return return
torch.cuda.synchronize() device_utils.synchronize()
self.timer.end() self.timer.end()
batch_size, seq_len = input_ids.shape batch_size, seq_len = input_ids.shape

2
op_builder/__init__.py

@ -1,3 +1,4 @@
from .arm_cpu_adam import ArmCPUAdamBuilder
from .cpu_adam import CPUAdamBuilder from .cpu_adam import CPUAdamBuilder
from .fused_optim import FusedOptimBuilder from .fused_optim import FusedOptimBuilder
from .layernorm import LayerNormBuilder from .layernorm import LayerNormBuilder
@ -29,4 +30,5 @@ __all__ = [
"MultiTensorLambBuilder", "MultiTensorLambBuilder",
"MultiTensorScaleBuilder", "MultiTensorScaleBuilder",
"MultiTensorL2NormBuilder", "MultiTensorL2NormBuilder",
"ArmCPUAdamBuilder",
] ]

34
op_builder/arm_cpu_adam.py

@ -0,0 +1,34 @@
from .builder import Builder
class ArmCPUAdamBuilder(Builder):
NAME = "arm_cpu_adam"
PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam"
ext_type = "cpu"
def __init__(self):
super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH)
self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
# necessary 4 functions
def sources_files(self):
ret = [
self.csrc_abs_path("cpu_adam_arm.cpp"),
]
return ret
def include_dirs(self):
return [self.csrc_abs_path("includes")]
def cxx_flags(self):
extra_cxx_flags = [
"-std=c++14",
"-std=c++17",
"-g",
"-Wno-reorder",
"-fopenmp",
]
return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
def nvcc_flags(self):
return []
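
Note: ArmCPUAdamBuilder fills in the four hooks the build path needs (sources_files, include_dirs, cxx_flags and nvcc_flags, which stays empty for a CPU-only kernel) and sets ext_type = "cpu" so that builder.py skips the CUDA runtime check. A hedged usage sketch, assuming the base Builder exposes the same load() entry point used by the existing CUDA op builders and that kernel selection by host architecture is done roughly as below:

    import platform

    from op_builder import ArmCPUAdamBuilder, CPUAdamBuilder

    # Illustrative selection: use the ARM kernel on aarch64 hosts and the AVX-based
    # x86 kernel elsewhere. load() returns the compiled extension module, JIT-building
    # it first if no prebuilt binary is found.
    builder = ArmCPUAdamBuilder() if platform.machine() == "aarch64" else CPUAdamBuilder()
    cpu_adam_ext = builder.load()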

19
op_builder/builder.py

@ -7,7 +7,7 @@ import os
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional, Union
from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0 from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0
@ -21,6 +21,8 @@ class Builder(ABC):
prebuilt_import_path (str): the path where the extension is installed during pip install prebuilt_import_path (str): the path where the extension is installed during pip install
""" """
ext_type: str = "cuda"
def __init__(self, name: str, prebuilt_import_path: str): def __init__(self, name: str, prebuilt_import_path: str):
self.name = name self.name = name
self.prebuilt_import_path = prebuilt_import_path self.prebuilt_import_path = prebuilt_import_path
@ -165,7 +167,8 @@ class Builder(ABC):
) )
except ImportError: except ImportError:
# check environment # check environment
self.check_runtime_build_environment() if self.ext_type == "cuda":
self.check_runtime_build_environment()
# time the kernel compilation # time the kernel compilation
start_build = time.time() start_build = time.time()
@ -208,11 +211,19 @@ class Builder(ABC):
return op_module return op_module
def builder(self) -> "CUDAExtension": def builder(self) -> Union["CUDAExtension", "CppExtension"]:
""" """
get a CUDAExtension instance used for setup.py get a CUDAExtension instance used for setup.py
""" """
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CppExtension, CUDAExtension
if self.ext_type == "cpp":
return CppExtension(
name=self.prebuilt_import_path,
sources=self.strip_empty_entries(self.sources_files()),
include_dirs=self.strip_empty_entries(self.include_dirs()),
extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
)
return CUDAExtension( return CUDAExtension(
name=self.prebuilt_import_path, name=self.prebuilt_import_path,

14
tests/test_booster/test_plugin/test_low_level_zero_plugin.py

@ -2,11 +2,14 @@ from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.optim import Adam
import colossalai import colossalai
import colossalai.utils.device as device_utils
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam
# from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo from tests.kit.model_zoo import model_zoo
@ -19,16 +22,17 @@ _STUCK_MODELS = ["transformers_albert_for_multiple_choice"]
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
device = device_utils.get_current_device()
try: try:
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5) plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin) booster = Booster(plugin=plugin)
model = model_fn() model = model_fn()
optimizer = HybridAdam(model.parameters(), lr=1e-3) optimizer = Adam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean() criterion = lambda x: x.mean()
data = data_gen_fn() data = data_gen_fn()
data = { data = {
k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items() k: v.to(device) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
} }
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
@ -65,7 +69,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
continue continue
err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn) err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn)
torch.cuda.empty_cache() device_utils.empty_cache()
if err is None: if err is None:
passed_models.append(name) passed_models.append(name)
@ -89,7 +93,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_low_level_zero_plugin(early_stop: bool = True): def test_low_level_zero_plugin(early_stop: bool = True):
spawn(run_dist, 4, early_stop=early_stop) spawn(run_dist, 2, early_stop=early_stop)
if __name__ == "__main__": if __name__ == "__main__":

2
tests/test_legacy/test_utils/test_memory.py

@ -3,7 +3,7 @@ import pytest
import colossalai import colossalai
from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
from colossalai.testing import spawn from colossalai.testing import spawn
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():

2
tests/test_zero/test_gemini/test_fwd_bwd.py

@ -9,7 +9,7 @@ from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd_bwd from tests.kit.model_zoo import model_zoo, run_fwd_bwd

2
tests/test_zero/test_gemini/test_grad_accum.py

@ -9,7 +9,7 @@ import colossalai
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd from tests.kit.model_zoo import model_zoo, run_fwd

2
tests/test_zero/test_gemini/test_inference.py

@ -11,7 +11,7 @@ from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd

2
tests/test_zero/test_gemini/test_optim.py

@ -9,7 +9,7 @@ from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd_bwd from tests.kit.model_zoo import model_zoo, run_fwd_bwd

19
tests/test_zero/test_low_level/test_grad_acc.py

@ -9,7 +9,7 @@ from torch.testing import assert_close
import colossalai import colossalai
from colossalai.testing import spawn from colossalai.testing import spawn
from colossalai.testing.random import seed_all from colossalai.testing.random import seed_all
from colossalai.utils import conditional_context from colossalai.utils import conditional_context, get_current_device
from colossalai.zero import LowLevelZeroOptimizer from colossalai.zero import LowLevelZeroOptimizer
@ -28,9 +28,9 @@ class MlpModel(nn.Module):
def exam_zero_1_2_grad_acc(): def exam_zero_1_2_grad_acc():
local_rank = torch.distributed.get_rank() local_rank = torch.distributed.get_rank()
seed_all(2009) seed_all(2009)
device = get_current_device()
# create model # create model
zero1_model = MlpModel().cuda() zero1_model = MlpModel().to(device)
zero2_model = copy.deepcopy(zero1_model) zero2_model = copy.deepcopy(zero1_model)
# create optimizer # create optimizer
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
@ -43,8 +43,8 @@ def exam_zero_1_2_grad_acc():
) )
# create data # create data
seed_all(2021 + local_rank) seed_all(2021 + local_rank)
input_data1 = torch.randn(32, 128).cuda() input_data1 = torch.randn(32, 128, device=device)
input_data2 = torch.randn(32, 128).cuda() input_data2 = torch.randn(32, 128, device=device)
def fwd_bwd_func(number, cur_data, check_flag): def fwd_bwd_func(number, cur_data, check_flag):
# zero-dp forward # zero-dp forward
@ -71,14 +71,15 @@ def exam_zero_1_2_grad_acc():
def exam_zero_1_grad_acc(sync): def exam_zero_1_grad_acc(sync):
local_rank = torch.distributed.get_rank() local_rank = torch.distributed.get_rank()
seed_all(2008) seed_all(2008)
device = get_current_device()
# create models # create models
zero_model = MlpModel() zero_model = MlpModel()
torch_model = copy.deepcopy(zero_model) torch_model = copy.deepcopy(zero_model)
seed_all(2008) seed_all(2008)
zero_model = zero_model.cuda() zero_model = zero_model.to(device)
torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0) torch_model = DDP(torch_model.to(device), bucket_cap_mb=0)
# create optimizer # create optimizer
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1) zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)
@ -94,8 +95,8 @@ def exam_zero_1_grad_acc(sync):
# create data # create data
seed_all(2022 + local_rank) seed_all(2022 + local_rank)
input_data1 = torch.randn(32, 128).cuda() input_data1 = torch.randn(32, 128, device=device)
input_data2 = torch.randn(32, 128).cuda() input_data2 = torch.randn(32, 128, device=device)
def fwd_bwd_func(no_sync, cur_data, check_flag): def fwd_bwd_func(no_sync, cur_data, check_flag):
# zero1 fwd and bwd # zero1 fwd and bwd
