
fix bugs in CPU Adam (#633)

* add a shared CPU Adam counter for all CPU Adam optimizers

* fix an update-size error in the Adam kernel
HELSON, 3 years ago, committed by GitHub
commit b31daed4cf
6 changed files:

  1. colossalai/kernel/cuda_native/csrc/cpu_adam.cpp (2 changes)
  2. colossalai/nn/optimizer/__init__.py (4 changes)
  3. colossalai/nn/optimizer/cpu_adam.py (5 changes)
  4. colossalai/nn/optimizer/hybrid_adam.py (5 changes)
  5. colossalai/nn/optimizer/utils.py (14 changes)
  6. tests/test_moe/test_moe_zero_optim.py (4 changes)

colossalai/kernel/cuda_native/csrc/cpu_adam.cpp (2 changes)

@@ -493,7 +493,7 @@ int adam_step(int optimizer_id,
               grads_ptr,
               exp_avg_ptr,
               exp_avg_sq_ptr,
-              params_c.size(0),
+              params_c.numel(),
               (params.options().dtype() == at::kHalf),
               (grads.options().dtype() == at::kHalf),
               loss_scale);
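
This is the kernel-update fix from the commit message: the fused step was previously told to process params_c.size(0) elements, which is only the length of the first dimension, while numel() is the total element count. A minimal PyTorch illustration of the difference (plain tensor calls, not the kernel itself):

    import torch

    # A 2-D weight, e.g. a small linear layer: 4 rows of 8 elements.
    w = torch.zeros(4, 8)

    print(w.size(0))    # 4  -> length of the first dimension only
    print(w.numel())    # 32 -> total number of elements to update

    # With size(0), the fused Adam step would have updated just the
    # first 4 of this parameter's 32 flattened values.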

colossalai/nn/optimizer/__init__.py (4 changes)

@@ -1,3 +1,4 @@
+from .utils import CPU_ADAM_CNT
 from .colossalai_optimizer import ColossalaiOptimizer
 from .fused_adam import FusedAdam
 from .fused_lamb import FusedLAMB
@@ -7,4 +8,5 @@ from .lars import Lars
 from .cpu_adam import CPUAdam
 from .hybrid_adam import HybridAdam
 
-__all__ = ['ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam']
+__all__ = ['ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD',
+           'Lamb', 'Lars', 'CPUAdam', 'HybridAdam', 'CPU_ADAM_CNT']

colossalai/nn/optimizer/cpu_adam.py (5 changes)

@@ -2,6 +2,7 @@ import math
 import torch
 
 from colossalai.registry import OPTIMIZERS
+from colossalai.nn.optimizer import CPU_ADAM_CNT
 
 
 @OPTIMIZERS.register_module
@@ -51,7 +52,6 @@ class CPUAdam(torch.optim.Optimizer):
     https://openreview.net/forum?id=ryQu7f-RZ
     """
 
-    optimizer_id = 0
     # Number of fp32 shards for per parameter
     # Param weight, grad, momentum and variance
     num_fp32_shards_per_param = 4
@@ -68,8 +68,7 @@ class CPUAdam(torch.optim.Optimizer):
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(CPUAdam, self).__init__(model_params, default_args)
 
-        self.opt_id = CPUAdam.optimizer_id
-        CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1
+        self.opt_id = CPU_ADAM_CNT()
         self.adamw_mode = adamw_mode
         try:
             import cpu_adam

colossalai/nn/optimizer/hybrid_adam.py (5 changes)

@@ -2,6 +2,7 @@ import torch
 
 from colossalai.utils import multi_tensor_applier
 from colossalai.registry import OPTIMIZERS
+from colossalai.nn.optimizer import CPU_ADAM_CNT
 
 
 @OPTIMIZERS.register_module
@@ -50,7 +51,6 @@ class HybridAdam(torch.optim.Optimizer):
     https://openreview.net/forum?id=ryQu7f-RZ
     """
 
-    optimizer_id = 0
     # Number of fp32 shards for per parameter
     # Param weight, grad, momentum and variance
     num_fp32_shards_per_param = 4
@@ -67,8 +67,7 @@ class HybridAdam(torch.optim.Optimizer):
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(HybridAdam, self).__init__(model_params, default_args)
 
-        self.opt_id = HybridAdam.optimizer_id
-        HybridAdam.optimizer_id = HybridAdam.optimizer_id + 1
+        self.opt_id = CPU_ADAM_CNT()
         self.adamw_mode = adamw_mode
         try:
             import cpu_adam
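
This is the counter bug the commit fixes: CPUAdam and HybridAdam each kept an independent class-level optimizer_id starting at 0, so the first instance of each class claimed the same id when registering with the cpu_adam C++ extension. A minimal sketch of that failure mode, using hypothetical stand-in classes rather than the real optimizers:

    # Stand-ins reproducing the old per-class counters (not the real classes).
    class CPUAdamOld:
        optimizer_id = 0

        def __init__(self):
            self.opt_id = CPUAdamOld.optimizer_id
            CPUAdamOld.optimizer_id += 1

    class HybridAdamOld:
        optimizer_id = 0

        def __init__(self):
            self.opt_id = HybridAdamOld.optimizer_id
            HybridAdamOld.optimizer_id += 1

    a, b = CPUAdamOld(), HybridAdamOld()
    assert a.opt_id == b.opt_id == 0    # both would register as optimizer 0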

colossalai/nn/optimizer/utils.py (14 changes, new file)

@@ -0,0 +1,14 @@
+class CpuAdamCounter(object):
+    """Used to record the total number of CPU Adam optimizers created.
+    We must use it to prevent HybridAdam and CPUAdam from using the same id.
+    """
+
+    def __init__(self):
+        self.number = 0
+
+    def __call__(self):
+        self.number += 1
+        return self.number - 1
+
+
+CPU_ADAM_CNT = CpuAdamCounter()
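
With this change, every optimizer instance draws its id from the single module-level CPU_ADAM_CNT, so ids stay unique process-wide regardless of which optimizer class asks. A small usage sketch, assuming a fresh process in which no optimizers have been constructed yet:

    from colossalai.nn.optimizer import CPU_ADAM_CNT

    # Successive calls hand out 0, 1, 2, ... shared across
    # CPUAdam and HybridAdam.
    assert CPU_ADAM_CNT() == 0
    assert CPU_ADAM_CNT() == 1
    assert CPU_ADAM_CNT() == 2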

tests/test_moe/test_moe_zero_optim.py (4 changes)

@@ -1,7 +1,6 @@
 from functools import partial
 
 import colossalai
-from colossalai.utils.cuda import get_current_device
 import pytest
 import torch
 import torch.multiprocessing as mp
@@ -51,11 +50,10 @@ def _run_step(model, optimizer, data, label, criterion, grad_handler):
 @parameterize("use_cpuadam", [True, False])
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
 def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio=0.0):
-    MOE_CONTEXT.reset_loss()
     shard_strategy = shard_strategy_class()
     if use_cpuadam and cpu_offload is False:
         return
+    MOE_CONTEXT.reset_loss()
     get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
     _, train_dataloader, _, optimizer_class, criterion = get_components_func()
