mirror of https://github.com/hpcaitech/ColossalAI
fix bugs in CPU adam (#633)

* add cpu adam counter for all cpu adam
* fixed updating error in adam kernel

parent 1e2557e801
commit b31daed4cf
CPU Adam C++ extension, adam_step():

@@ -493,7 +493,7 @@ int adam_step(int optimizer_id,
                     grads_ptr,
                     exp_avg_ptr,
                     exp_avg_sq_ptr,
-                    params_c.size(0),
+                    params_c.numel(),
                     (params.options().dtype() == at::kHalf),
                     (grads.options().dtype() == at::kHalf),
                     loss_scale);
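The kernel fix above passes params_c.numel() instead of params_c.size(0) as the element count. For a flattened 1-D parameter tensor the two agree, but for a multi-dimensional tensor size(0) only counts the first dimension, so the kernel would update only part of the parameters. A minimal PyTorch sketch of the difference (tensor shape chosen for illustration only):

import torch

# A 2-D parameter tensor, e.g. a linear layer weight.
w = torch.zeros(4, 8)

print(w.size(0))   # 4  -> size of the first dimension only
print(w.numel())   # 32 -> total number of elements the kernel must update

# Passing size(0) as the element count would leave 28 of the 32
# entries untouched by the optimizer step; numel() covers them all.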
colossalai/nn/optimizer/__init__.py:

@@ -1,3 +1,4 @@
+from .utils import CPU_ADAM_CNT
 from .colossalai_optimizer import ColossalaiOptimizer
 from .fused_adam import FusedAdam
 from .fused_lamb import FusedLAMB
@@ -7,4 +8,5 @@ from .lars import Lars
 from .cpu_adam import CPUAdam
 from .hybrid_adam import HybridAdam
 
-__all__ = ['ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam']
+__all__ = ['ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD',
+           'Lamb', 'Lars', 'CPUAdam', 'HybridAdam', 'CPU_ADAM_CNT']
colossalai/nn/optimizer/cpu_adam.py:

@@ -2,6 +2,7 @@ import math
 import torch
 
 from colossalai.registry import OPTIMIZERS
+from colossalai.nn.optimizer import CPU_ADAM_CNT
 
 
 @OPTIMIZERS.register_module
@@ -51,7 +52,6 @@ class CPUAdam(torch.optim.Optimizer):
         https://openreview.net/forum?id=ryQu7f-RZ
     """
 
-    optimizer_id = 0
     # Number of fp32 shards for per parameter
     # Param weight, grad, momentum and variance
     num_fp32_shards_per_param = 4
@@ -68,8 +68,7 @@ class CPUAdam(torch.optim.Optimizer):
 
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(CPUAdam, self).__init__(model_params, default_args)
-        self.opt_id = CPUAdam.optimizer_id
-        CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1
+        self.opt_id = CPU_ADAM_CNT()
         self.adamw_mode = adamw_mode
         try:
             import cpu_adam
@@ -152,8 +151,8 @@ class CPUAdam(torch.optim.Optimizer):
                 assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
                 assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
 
-                bias_correction1 = 1 - beta1**state['step']
-                bias_correction2 = 1 - beta2**state['step']
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']
 
                 # adam on cuda
                 self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
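The bias corrections computed in the last hunk follow the standard Adam formulation, with torch_adam_update presumably acting as the pure-PyTorch path when the state lives on CUDA. A minimal sketch of a bias-corrected Adam step in plain PyTorch (the function name and signature here are illustrative, not the library's API):

import torch

def adam_update_sketch(param, grad, exp_avg, exp_avg_sq, lr,
                       beta1=0.9, beta2=0.999, eps=1e-8, step=1):
    # Standard Adam moment updates.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    # Bias corrections, matching the lines in the diff above.
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step

    # Bias-corrected parameter update.
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
    param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)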
colossalai/nn/optimizer/hybrid_adam.py:

@@ -2,6 +2,7 @@ import torch
 
 from colossalai.utils import multi_tensor_applier
 from colossalai.registry import OPTIMIZERS
+from colossalai.nn.optimizer import CPU_ADAM_CNT
 
 
 @OPTIMIZERS.register_module
@@ -50,7 +51,6 @@ class HybridAdam(torch.optim.Optimizer):
         https://openreview.net/forum?id=ryQu7f-RZ
     """
 
-    optimizer_id = 0
     # Number of fp32 shards for per parameter
     # Param weight, grad, momentum and variance
     num_fp32_shards_per_param = 4
@@ -67,8 +67,7 @@ class HybridAdam(torch.optim.Optimizer):
 
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(HybridAdam, self).__init__(model_params, default_args)
-        self.opt_id = HybridAdam.optimizer_id
-        HybridAdam.optimizer_id = HybridAdam.optimizer_id + 1
+        self.opt_id = CPU_ADAM_CNT()
         self.adamw_mode = adamw_mode
         try:
             import cpu_adam
colossalai/nn/optimizer/utils.py (new file):

@@ -0,0 +1,14 @@
+class CpuAdamCounter(object):
+    """Used to record the total number of CPU Adam.
+    We must use it to avoid hybrid cpu adam and cpu adam using the same id.
+    """
+
+    def __init__(self):
+        self.number = 0
+
+    def __call__(self):
+        self.number += 1
+        return self.number - 1
+
+
+CPU_ADAM_CNT = CpuAdamCounter()
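This shared counter is the core of the id fix: previously CPUAdam and HybridAdam each kept their own class-level optimizer_id, so the first instance of each class received id 0 and registered under the same id in the cpu_adam extension. A minimal sketch of how the shared CPU_ADAM_CNT hands out unique ids across both classes (standalone toy classes, not the real optimizers):

class CpuAdamCounter(object):
    """Shared counter, as in the new utils.py above."""

    def __init__(self):
        self.number = 0

    def __call__(self):
        self.number += 1
        return self.number - 1


CPU_ADAM_CNT = CpuAdamCounter()


class ToyCPUAdam:
    def __init__(self):
        self.opt_id = CPU_ADAM_CNT()    # draws from the shared counter


class ToyHybridAdam:
    def __init__(self):
        self.opt_id = CPU_ADAM_CNT()    # same counter, so ids never collide


print(ToyCPUAdam().opt_id, ToyHybridAdam().opt_id)    # prints: 0 1
# With the old per-class counters, both instances would have received id 0.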
Sharded optimizer v2 test (_run_test_sharded_optim_v2):

@@ -1,7 +1,6 @@
 from functools import partial
 
 import colossalai
-from colossalai.utils.cuda import get_current_device
 import pytest
 import torch
 import torch.multiprocessing as mp
@@ -51,11 +50,10 @@ def _run_step(model, optimizer, data, label, criterion, grad_handler):
 @parameterize("use_cpuadam", [True, False])
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
 def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio=0.0):
-    MOE_CONTEXT.reset_loss()
     shard_strategy = shard_strategy_class()
     if use_cpuadam and cpu_offload is False:
         return
+    MOE_CONTEXT.reset_loss()
     get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
     _, train_dataloader, _, optimizer_class, criterion = get_components_func()
 