diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index 0b078dc88..f3ef91d0e 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -20,6 +20,7 @@ class CPUAdam(torch.optim.Optimizer):
         The difference is that model_params are sharded parameters belonging to a ShardedModelV2 instance.
         The sharded param of model_params can resident on both CPU and CUDA.
         """
+
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(CPUAdam, self).__init__(model_params, default_args)
         self.opt_id = CPUAdam.optimizer_id
@@ -34,7 +35,8 @@ class CPUAdam(torch.optim.Optimizer):
         self.cpu_adam_op.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, simd_log)
 
     def __del__(self):
-        self.cpu_adam_op.destroy_adam(self.opt_id)
+        if self.cpu_adam_op:
+            self.cpu_adam_op.destroy_adam(self.opt_id)
 
     def torch_adam_update(self,
                           data,
@@ -72,7 +74,6 @@ class CPUAdam(torch.optim.Optimizer):
 
     @torch.no_grad()
     def step(self, closure=None):
-
         loss = None
         if closure is not None:
             with torch.enable_grad():
diff --git a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
index d5ba72a2e..c118f7710 100644
--- a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
+++ b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py
@@ -2,9 +2,10 @@ from typing import List
 
 import torch
 import torch.distributed as dist
+from torch._utils import _flatten_dense_tensors as flatten
+
 from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
-from torch._utils import _flatten_dense_tensors as flatten
 
 from .tensor_shard_strategy import TensorShardStrategy
 
diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/shard_utils/tensor_shard_strategy.py
index 08ac39e7d..4e4bdaabb 100644
--- a/colossalai/zero/shard_utils/tensor_shard_strategy.py
+++ b/colossalai/zero/shard_utils/tensor_shard_strategy.py
@@ -2,6 +2,7 @@ from typing import List, Optional
 
 import torch
 import torch.distributed as dist
+
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index a7a24ef64..2347dd125 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -1,9 +1,14 @@
 from enum import Enum
-from typing import Callable, Dict, Optional, Union
+from typing import Dict, Optional, Type, Any
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch import Tensor
+from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter
+from torch.optim import Optimizer
+
 from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
@@ -11,11 +16,8 @@ from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
-from torch import Tensor
-from torch.distributed import ProcessGroup
-from torch.nn.parameter import Parameter
-from torch.optim import Optimizer
-from typing import Type, Any
+from colossalai.logging import get_dist_logger
+
 from ._utils import has_inf_or_nan
 
 
@@ -82,7 +84,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         :type defaults: dict()
         """
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
-
+        self._logger = get_dist_logger('ShardedOptimV2 logger')
         self._optim_defaults = defaults
         # initialize the M, V as zeros tensors and initialize param fp32 from sharded_model.parameters()
 
@@ -136,23 +138,24 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self.grad_scaler.update(found_inf)
 
         if found_inf:
+            self._logger.info('found inf during ShardedOptimV2 step')
             self.zero_grad()
             return
 
         # assign master param pointers to p.data.
         # We will not trigger data copy here.
-        for group in self.optim.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 p.data = self.master_params[p]
                 # Now p.data is sharded
                 # So optimizer states are sharded naturally
 
-        ret = self.optim.step(*args, **kwargs)
+        ret = self.optimizer.step(*args, **kwargs)
 
         # Copy master param data (fp32) to payload of col_attr (fp16)
         # TODO() improve efficiency by gathering tensors into a chunk and transfering
         # a chunk.
-        for group in self.optim.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 is_param_sharded = p.col_attr.data.is_sharded
                 if not is_param_sharded:
@@ -196,7 +199,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._found_overflow.fill_(0.0)
 
         # check for overflow
-        for group in self.optim.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 if has_inf_or_nan(p.grad):
                     self._found_overflow.fill_(1.0)
@@ -212,7 +215,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
 
     def _unscale_grads(self):
         assert self.optim_state == OptimState.SCALED
-        for group in self.optim.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 if p.grad is not None:
                     p.grad.data.div_(self.loss_scale)
@@ -222,7 +225,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # We must set grad to None
         # Because we will judge whether local grad accumulation
        # is enabled by wheter grad is None
-        self.optim.zero_grad(set_to_none=True)
+        self.optimizer.zero_grad(set_to_none=True)
 
     def sync_grad(self):
         pass
diff --git a/tests/test_zero_data_parallel/common.py b/tests/test_zero_data_parallel/common.py
index 3236c54e5..e43671eed 100644
--- a/tests/test_zero_data_parallel/common.py
+++ b/tests/test_zero_data_parallel/common.py
@@ -6,6 +6,7 @@ import torch.distributed as dist
 from colossalai.logging import get_dist_logger
 from colossalai.utils import checkpoint
 from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.nn.optimizer import CPUAdam
 
 LOGGER = get_dist_logger('zero_test')
 
@@ -19,16 +20,16 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                           use_memory_tracer=False)
 
 _ZERO_OPTIMIZER_CONFIG = dict(
-    optimizer_class=torch.optim.Adam,
+    optimizer_class=torch.optim.Adam,    #CPUAdam
     cpu_offload=False,
-    initial_scale=2**32,
+    initial_scale=2**5,
     min_scale=1,
     growth_factor=2,
     backoff_factor=0.5,
     growth_interval=1000,
     hysteresis=2,
     max_scale=2**32,
-)
+    lr=1e-3)
 
 ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
                             zero=dict(
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2.py b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
index 5ae36cd04..34590c57c 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_v2.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
@@ -13,6 +13,7 @@ from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 from colossalai.nn.optimizer import CPUAdam
+from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 
 from common import CONFIG, check_sharded_params_padding
 
@@ -71,6 +72,8 @@ def _run_dist(rank, world_size, port, cpu_offload, shard_strategy, use_cpuadam):
         _run_step(model, optim, data, label, criterion, False)
         _run_step(zero_model, sharded_optim, data, label, criterion, False)
         check_sharded_params_padding(model, zero_model, loose=True)
+        for param in model.parameters():
+            assert not has_inf_or_nan(param)
 
 
 # use_cpuadam = True can be used with cpu_offload = False
@@ -105,7 +108,4 @@ def test_sharded_optim_v2_cpu_adam(world_size, cpu_offload, shard_strategy, use_
 
 
 if __name__ == '__main__':
-    test_sharded_optim_v2_cpu_adam(world_size=2,
-                                   cpu_offload=False,
-                                   shard_strategy=TensorShardStrategy,
-                                   use_cpuadam=True)
+    test_sharded_optim_v2_cpu_adam(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy, use_cpuadam=True)
diff --git a/tests/test_zero_data_parallel/test_zero_engine.py b/tests/test_zero_data_parallel/test_zero_engine.py
index cdd2bbc5e..f6d814c69 100644
--- a/tests/test_zero_data_parallel/test_zero_engine.py
+++ b/tests/test_zero_data_parallel/test_zero_engine.py
@@ -8,6 +8,7 @@ import pytest
 
 import colossalai
 from colossalai.utils import free_port
+from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 import torch.multiprocessing as mp
 import torch.distributed as dist
 
@@ -32,12 +33,13 @@ def run_dist(rank, world_size, port, parallel_config):
        colo_model = model_builder(checkpoint=True)
 
        torch_model = copy.deepcopy(colo_model).cuda()
+       torch_model.train()
        engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
                                                               optimizer=optimizer_class,
                                                               criterion=criterion,
                                                               train_dataloader=train_dataloader)
        engine.train()
-       torch_optimizer = optimizer_class(torch_model.parameters())
+       torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3)
 
        if dist.get_world_size() > 1:
            torch_model = DDP(torch_model)
@@ -66,15 +68,17 @@ def run_dist(rank, world_size, port, parallel_config):
            engine.step()
 
            torch_loss.backward()
+
+           for param in torch_model.parameters():
+               if param.grad is not None:
+                   assert not has_inf_or_nan(param.grad)
+
            torch_optimizer.step()
            i += 1
 
-       # for torch_param, zero_param in zip(torch_model.parameters(), colo_model.parameters()):
-       #     assert torch.allclose(torch_param, zero_param), f"diff {torch_param - zero_param}"
-
        if parallel_config == MP_PARALLEL_CONFIG:
            check_params(torch_model, colo_model, loose=True)
-       elif isinstance(colo_model, ShardedModelV2):
+       elif parallel_config == ZERO_PARALLEL_CONFIG:
            check_sharded_params_padding(torch_model, colo_model, loose=True)
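
Note: the new assertions in the tests above rely on colossalai.zero.sharded_optim._utils.has_inf_or_nan to catch overflowed parameters and gradients after each step. A minimal sketch of an equivalent check and of the usage pattern the tests follow (illustrative only; the helper body shown here is an assumption, not necessarily the library's actual implementation):

    import torch

    def has_inf_or_nan(tensor: torch.Tensor) -> bool:
        """Return True if the tensor contains any inf or NaN entries."""
        # Illustrative assumption: a single reduction is enough, since any
        # inf/NaN entry poisons the sum (computed in fp32 to avoid fp16 overflow).
        total = tensor.float().sum()
        return bool(torch.isinf(total) or torch.isnan(total))

    # Usage mirroring the tests: no parameter (or gradient) should have overflowed.
    model = torch.nn.Linear(4, 4)
    for param in model.parameters():
        assert not has_inf_or_nan(param)
        if param.grad is not None:
            assert not has_inf_or_nan(param.grad)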