[test] make zero engine test really work (#447)

pull/451/head
Jiarui Fang 2022-03-17 17:24:25 +08:00 committed by GitHub
parent bb2790cf0b
commit 0fcfb1e00d
7 changed files with 39 additions and 28 deletions

View File

@ -20,6 +20,7 @@ class CPUAdam(torch.optim.Optimizer):
The difference is that model_params are sharded parameters belonging to a ShardedModelV2 instance.
The sharded params of model_params can reside on both CPU and CUDA.
"""
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(CPUAdam, self).__init__(model_params, default_args)
self.opt_id = CPUAdam.optimizer_id
@ -34,7 +35,8 @@ class CPUAdam(torch.optim.Optimizer):
self.cpu_adam_op.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode, simd_log)
def __del__(self):
self.cpu_adam_op.destroy_adam(self.opt_id)
if self.cpu_adam_op:
self.cpu_adam_op.destroy_adam(self.opt_id)
def torch_adam_update(self,
data,
@ -72,7 +74,6 @@ class CPUAdam(torch.optim.Optimizer):
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
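The __del__ change above guards destroy_adam so that a partially constructed optimizer (for example, when the compiled cpu_adam extension was never assigned) does not raise during garbage collection. A minimal sketch of that pattern, using a hypothetical DummyAdamBackend and GuardedOptimizer in place of the real extension and CPUAdam:

# Sketch of the guarded-destructor pattern; DummyAdamBackend stands in for
# the compiled cpu_adam extension (illustrative names, not the real classes).
class DummyAdamBackend:
    def create_adam(self, opt_id, *args):
        print(f"create_adam({opt_id})")

    def destroy_adam(self, opt_id):
        print(f"destroy_adam({opt_id})")


class GuardedOptimizer:
    def __init__(self, opt_id, backend=None):
        self.opt_id = opt_id
        # backend may be None if extension loading failed earlier
        self.cpu_adam_op = backend

    def __del__(self):
        # Only release native state if the backend was actually created,
        # mirroring the `if self.cpu_adam_op:` guard added to CPUAdam.__del__.
        if self.cpu_adam_op:
            self.cpu_adam_op.destroy_adam(self.opt_id)


opt = GuardedOptimizer(0, DummyAdamBackend())
del opt  # prints destroy_adam(0); a backend of None would simply be skipped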

View File

@ -2,9 +2,10 @@ from typing import List
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from torch._utils import _flatten_dense_tensors as flatten
from .tensor_shard_strategy import TensorShardStrategy
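The _flatten_dense_tensors import reordered above is what lets a bucketed shard strategy pack many small shard tensors into one contiguous buffer before a single collective call. A standalone round-trip sketch (no process group needed; the comment about the collective is an assumption about intent, not this file's code):

import torch
from torch._utils import _flatten_dense_tensors as flatten
from torch._utils import _unflatten_dense_tensors as unflatten

# Three small "shards" of different shapes, as a bucket strategy might collect.
shards = [torch.randn(4), torch.randn(2, 3), torch.randn(5)]

# Pack them into one contiguous 1-D buffer (4 + 6 + 5 = 15 elements).
bucket = flatten(shards)
assert bucket.numel() == sum(s.numel() for s in shards)

# In the real strategy the bucket would feed one collective call here
# instead of many per-tensor calls.

# Unpack back into views shaped like the originals.
for shard, restored in zip(shards, unflatten(bucket, shards)):
    assert torch.equal(shard, restored)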

View File

@ -2,6 +2,7 @@ from typing import List, Optional
import torch
import torch.distributed as dist
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._zero3_utils import get_shard
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor

View File

@ -1,9 +1,14 @@
from enum import Enum
from typing import Callable, Dict, Optional, Union
from typing import Dict, Optional, Type, Any
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
@ -11,11 +16,8 @@ from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from typing import Type, Any
from colossalai.logging import get_dist_logger
from ._utils import has_inf_or_nan
@ -82,7 +84,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
:type defaults: dict()
"""
assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
self._logger = get_dist_logger('ShardedOptimV2 logger')
self._optim_defaults = defaults
# initialize M and V as zero tensors and initialize param fp32 from sharded_model.parameters()
@ -136,23 +138,24 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self.grad_scaler.update(found_inf)
if found_inf:
self._logger.info('found inf during ShardedOptimV2 step')
self.zero_grad()
return
# assign master param pointers to p.data.
# We will not trigger data copy here.
for group in self.optim.param_groups:
for group in self.optimizer.param_groups:
for p in group['params']:
p.data = self.master_params[p]
# Now p.data is sharded
# So optimizer states are sharded naturally
ret = self.optim.step(*args, **kwargs)
ret = self.optimizer.step(*args, **kwargs)
# Copy master param data (fp32) to payload of col_attr (fp16)
# TODO() improve efficiency by gathering tensors into a chunk and transferring
# a chunk.
for group in self.optim.param_groups:
for group in self.optimizer.param_groups:
for p in group['params']:
is_param_sharded = p.col_attr.data.is_sharded
if not is_param_sharded:
@ -196,7 +199,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self._found_overflow.fill_(0.0)
# check for overflow
for group in self.optim.param_groups:
for group in self.optimizer.param_groups:
for p in group['params']:
if has_inf_or_nan(p.grad):
self._found_overflow.fill_(1.0)
@ -212,7 +215,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
def _unscale_grads(self):
assert self.optim_state == OptimState.SCALED
for group in self.optim.param_groups:
for group in self.optimizer.param_groups:
for p in group['params']:
if p.grad is not None:
p.grad.data.div_(self.loss_scale)
@ -222,7 +225,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# We must set grad to None
# Because we will judge whether local grad accumulation
# is enabled by whether grad is None
self.optim.zero_grad(set_to_none=True)
self.optimizer.zero_grad(set_to_none=True)
def sync_grad(self):
pass
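The hunks above rename self.optim to self.optimizer and make the overflow path zero the gradients before skipping the step. A condensed, hypothetical sketch of that order of operations (sharded_step and the plain master_params dict are simplified stand-ins for ShardedOptimizerV2's internals; the fp16 copy-back is only indicative of the col_attr handling):

import torch

def sharded_step(optimizer, master_params, grad_scaler, found_inf: bool):
    """Illustrative flow only; mirrors the sequence shown in the diff."""
    # 1. Feed the overflow flag to the scaler, then bail out early,
    #    clearing grads so no stale scaled gradients survive the skipped step.
    grad_scaler.update(found_inf)
    if found_inf:
        optimizer.zero_grad(set_to_none=True)
        return None

    # 2. Point p.data at the fp32 master shard (no data copy here), so the
    #    wrapped optimizer updates naturally sharded fp32 states.
    for group in optimizer.param_groups:
        for p in group['params']:
            p.data = master_params[p]

    ret = optimizer.step()

    # 3. Copy the updated fp32 master data back to the fp16 side
    #    (a simple cast stands in for the real payload copy).
    for group in optimizer.param_groups:
        for p in group['params']:
            p.data = master_params[p].to(torch.half)
    return ret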

View File

@ -6,6 +6,7 @@ import torch.distributed as dist
from colossalai.logging import get_dist_logger
from colossalai.utils import checkpoint
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.nn.optimizer import CPUAdam
LOGGER = get_dist_logger('zero_test')
@ -19,16 +20,16 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
use_memory_tracer=False)
_ZERO_OPTIMIZER_CONFIG = dict(
optimizer_class=torch.optim.Adam,
optimizer_class=torch.optim.Adam, #CPUAdam
cpu_offload=False,
initial_scale=2**32,
initial_scale=2**5,
min_scale=1,
growth_factor=2,
backoff_factor=0.5,
growth_interval=1000,
hysteresis=2,
max_scale=2**32,
)
lr=1e-3)
ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
zero=dict(
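The config above drops initial_scale from 2**32 to 2**5 so the first fp16 steps do not immediately overflow, and now passes lr=1e-3 through to the wrapped optimizer. A toy dynamic-loss-scale loop driven by the same knobs (a common update scheme for illustration, not necessarily the exact rule in colossalai's DynamicGradScaler):

class ToyScaler:
    def __init__(self, initial_scale=2**5, min_scale=1, growth_factor=2,
                 backoff_factor=0.5, growth_interval=1000, hysteresis=2,
                 max_scale=2**32):
        self.scale = float(initial_scale)
        self.min_scale, self.max_scale = min_scale, max_scale
        self.growth_factor, self.backoff_factor = growth_factor, backoff_factor
        self.growth_interval, self.hysteresis = growth_interval, hysteresis
        self._good_steps, self._overflows_left = 0, hysteresis

    def update(self, found_inf: bool):
        if found_inf:
            self._good_steps = 0
            self._overflows_left -= 1
            if self._overflows_left <= 0:
                # back off after `hysteresis` consecutive overflowing steps
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
                self._overflows_left = self.hysteresis
        else:
            self._good_steps += 1
            self._overflows_left = self.hysteresis
            if self._good_steps % self.growth_interval == 0:
                # grow again after `growth_interval` clean steps in a row
                self.scale = min(self.scale * self.growth_factor, self.max_scale)


scaler = ToyScaler()
scaler.update(found_inf=True)
scaler.update(found_inf=True)
assert scaler.scale == 16.0  # 2**5 halved after two consecutive overflows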

View File

@ -13,6 +13,7 @@ from colossalai.zero.sharded_optim import ShardedOptimizerV2
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.optimizer import CPUAdam
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from common import CONFIG, check_sharded_params_padding
@ -71,6 +72,8 @@ def _run_dist(rank, world_size, port, cpu_offload, shard_strategy, use_cpuadam):
_run_step(model, optim, data, label, criterion, False)
_run_step(zero_model, sharded_optim, data, label, criterion, False)
check_sharded_params_padding(model, zero_model, loose=True)
for param in model.parameters():
assert not has_inf_or_nan(param)
# use_cpuadam = True can be used with cpu_offload = False
@ -105,7 +108,4 @@ def test_sharded_optim_v2_cpu_adam(world_size, cpu_offload, shard_strategy, use_
if __name__ == '__main__':
test_sharded_optim_v2_cpu_adam(world_size=2,
cpu_offload=False,
shard_strategy=TensorShardStrategy,
use_cpuadam=True)
test_sharded_optim_v2_cpu_adam(world_size=2, cpu_offload=True, shard_strategy=TensorShardStrategy, use_cpuadam=True)
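The test above now asserts not has_inf_or_nan(param) on every parameter after the final step. A standalone approximation of that kind of check (the real helper lives in colossalai.zero.sharded_optim._utils and may differ in detail; looks_finite is an illustrative name):

import torch

def looks_finite(tensor: torch.Tensor) -> bool:
    # True only if the tensor contains no inf and no NaN entries.
    return bool(torch.isfinite(tensor).all())

assert looks_finite(torch.randn(4))
assert not looks_finite(torch.tensor([1.0, float('inf')]))
assert not looks_finite(torch.tensor([float('nan')]))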

View File

@ -8,6 +8,7 @@ import pytest
import colossalai
from colossalai.utils import free_port
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
import torch.multiprocessing as mp
import torch.distributed as dist
@ -32,12 +33,13 @@ def run_dist(rank, world_size, port, parallel_config):
colo_model = model_builder(checkpoint=True)
torch_model = copy.deepcopy(colo_model).cuda()
torch_model.train()
engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
optimizer=optimizer_class,
criterion=criterion,
train_dataloader=train_dataloader)
engine.train()
torch_optimizer = optimizer_class(torch_model.parameters())
torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3)
if dist.get_world_size() > 1:
torch_model = DDP(torch_model)
@ -66,15 +68,17 @@ def run_dist(rank, world_size, port, parallel_config):
engine.step()
torch_loss.backward()
for param in torch_model.parameters():
if param.grad is not None:
assert not has_inf_or_nan(param.grad)
torch_optimizer.step()
i += 1
# for torch_param, zero_param in zip(torch_model.parameters(), colo_model.parameters()):
# assert torch.allclose(torch_param, zero_param), f"diff {torch_param - zero_param}"
if parallel_config == MP_PARALLEL_CONFIG:
check_params(torch_model, colo_model, loose=True)
elif isinstance(colo_model, ShardedModelV2):
elif parallel_config == ZERO_PARALLEL_CONFIG:
check_sharded_params_padding(torch_model, colo_model, loose=True)
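The branch above now selects the comparison helper from the parallel config rather than from the model's runtime type. A loose parameter-parity check in the spirit of the commented-out allclose and the check_* helpers (check_params_loose, its tolerances, and the flatten-then-truncate handling of padding are placeholder choices, not the helpers' real logic):

import torch
import torch.nn as nn

def check_params_loose(ref_model: nn.Module, test_model: nn.Module,
                       rtol: float = 1e-3, atol: float = 1e-3) -> None:
    # Compare parameters pairwise with loose tolerances; fp16 training and
    # shard padding mean exact equality is not expected.
    for (name, ref_p), (_, test_p) in zip(ref_model.named_parameters(),
                                          test_model.named_parameters()):
        ref = ref_p.detach().float().cpu().flatten()
        test = test_p.detach().float().cpu().flatten()[:ref.numel()]
        assert torch.allclose(ref, test, rtol=rtol, atol=atol), \
            f"parameter {name} diverged: max diff {(ref - test).abs().max()}"

# Example: two identical tiny models pass the check.
torch.manual_seed(0)
m1 = nn.Linear(4, 4)
m2 = nn.Linear(4, 4)
m2.load_state_dict(m1.state_dict())
check_params_loose(m1, m2)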