Revert "[zero] update sharded optim and fix zero init ctx" (#456)

* Revert "polish code"

This reverts commit 8cf7ff08cf.

* Revert "rename variables"

This reverts commit e99af94ab8.

* Revert "remove surplus imports"

This reverts commit 46add4a5c5.

* Revert "update sharded optim and fix zero init ctx"

This reverts commit 57567ee768.
Jiarui Fang 2022-03-18 15:22:43 +08:00 committed by GitHub
parent 8cf7ff08cf
commit e2e9f82588
11 changed files with 161 additions and 161 deletions

View File

@@ -5,7 +5,7 @@ import argparse
import os
import pprint
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union, Type
import torch
import torch.nn as nn
@@ -21,13 +21,13 @@ from colossalai.builder.builder import build_gradient_handler
from colossalai.context import Config, ConfigException, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.engine.ophooks import BaseOpHook
from colossalai.global_variables import moe_env
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
sync_model_param)
from colossalai.zero import convert_to_zero_v2
from colossalai.engine.ophooks import BaseOpHook
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
@@ -217,8 +217,8 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
verbose=verbose)
def initialize(model: nn.Module,
optimizer: Optimizer,
def initialize(model: Union[Callable, nn.Module],
optimizer: Union[Type[Optimizer], Optimizer],
criterion: Optional[_Loss] = None,
train_dataloader: Optional[Iterable] = None,
test_dataloader: Optional[Iterable] = None,
@@ -278,10 +278,12 @@ def initialize(model: nn.Module,
cfg_ = {}
optimizer_config = zero_cfg.get('optimizer_config', None)
model_config = zero_cfg.get('model_config', None)
model, optimizer = convert_to_zero_v2(model, model_config=model_config, optimizer_config=optimizer_config)
model, optimizer = convert_to_zero_v2(model_builder=model,
model_config=model_config,
optimizer_config=optimizer_config)
logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
# FIXME() throw a warning if using zero with MP
#FIXME() throw a warning if using zero with MP
if gpc.get_world_size(ParallelMode.MODEL) > 1:
logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0])
else:
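
A minimal usage sketch of the restored call pattern, not taken from this diff: a model builder callable and an optimizer class passed straight to colossalai.initialize under a ZeRO config. The config layout follows ZERO_PARALLEL_CONFIG from the tests further down; the toy model, the config values, and the torchrun launch are placeholder assumptions.

# Sketch only: assumes the process group is launched via torchrun, since
# colossalai.launch_from_torch reads rank/world size from the environment.
import torch
import torch.nn as nn

import colossalai

CONFIG = dict(
    fp16=dict(mode=None),
    zero=dict(
        model_config=dict(offload_config=None, shard_param=True),
        optimizer_config=dict(optimizer_class=torch.optim.Adam, lr=1e-3, initial_scale=2**5),
    ),
)


def build_model() -> nn.Module:
    # hypothetical toy model; with ZeRO enabled, initialize() forwards this callable
    # to convert_to_zero_v2 as model_builder so parameters are sharded at build time
    return nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))


colossalai.launch_from_torch(config=CONFIG)
engine, *_ = colossalai.initialize(model=build_model,
                                   optimizer=torch.optim.Adam,
                                   criterion=nn.CrossEntropyLoss())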

View File

@@ -1,17 +1,22 @@
from typing import Tuple
from typing import Callable
import torch
import torch.nn as nn
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.logging import get_dist_logger
from torch.optim import Optimizer
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
from torch.optim import Optimizer
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.core import global_context as gpc
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.logging import get_dist_logger
from .sharded_model import ShardedModel
from .sharded_optim import ShardedOptimizer
def convert_to_zero_v2(model: nn.Module, model_config, optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]:
def convert_to_zero_v2(model_builder: Callable, model_config, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):
"""
A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
@@ -26,6 +31,9 @@ def convert_to_zero_v2(model: nn.Module, model_config, optimizer_config) -> Tupl
logger = get_dist_logger('convert_to_zero_v2')
# FIXME() pass shard strategy from config
shard_strategy = TensorShardStrategy()
logger.info(f'optimizer_config is {optimizer_config}')
if optimizer_config is None:
optimizer_config = dict()
@@ -33,7 +41,18 @@ def convert_to_zero_v2(model: nn.Module, model_config, optimizer_config) -> Tupl
if model_config is None:
model_config = dict()
zero_model = ShardedModelV2(model, **model_config)
if isinstance(model_builder, nn.Module):
model = model_builder
elif isinstance(model_builder, Callable):
with ZeroInitContext(convert_fp16='fp16' in gpc.config,
target_device=torch.cuda.current_device(),
shard_strategy=shard_strategy,
shard_param=model_config.get('shard_param', True)):
model = model_builder()
else:
raise TypeError(f"convert_to_zero_v2 does not support model_builder of type {type(model_builder)}")
zero_model = ShardedModelV2(model, shard_strategy=shard_strategy, **model_config)
zero_optimizer = ShardedOptimizerV2(zero_model, **optimizer_config)
return zero_model, zero_optimizer
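
A hedged usage sketch of the restored helper: convert_to_zero_v2 accepts either an already-built nn.Module or a builder callable, which it runs inside ZeroInitContext. It assumes an initialized Colossal-AI context (colossalai.launch has set up gpc and the process groups); the builder and config values below are placeholders.

# Sketch only: run under an initialized distributed context, since
# convert_to_zero_v2 reads gpc.config and the shard strategy communicates
# over the DATA process group.
import torch
import torch.nn as nn

from colossalai.zero import convert_to_zero_v2


def build_model() -> nn.Module:
    # hypothetical builder; it runs inside ZeroInitContext so parameters are
    # sharded (and optionally cast to fp16) as they are created
    return nn.Linear(1024, 1024)


zero_model, zero_optim = convert_to_zero_v2(
    model_builder=build_model,  # an already-built nn.Module also works
    model_config=dict(shard_param=True, offload_config=None),
    optimizer_config=dict(optimizer_class=torch.optim.Adam, lr=1e-3),
)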

View File

@@ -4,7 +4,6 @@ import torch
from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
from colossalai.zero.sharded_param import ShardedParamV2
# Inserts _post_init_method at the end of init method
@@ -159,10 +158,3 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
# if param.col_attr.grad and self.shard_grad:
# self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
# GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
# We must cast buffers
# If we use BN, buffers may be on CPU and Float
# We must cast them
for buffer in module.buffers():
buffer.data = buffer.data.to(device=torch.cuda.current_device())
if self.convert_fp16:
buffer.data = cast_tensor_to_fp16(buffer.data)
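
For context, a hedged sketch of how ZeroInitContext is driven in the tests further down: parameters created inside the context receive a col_attr and are sharded eagerly. It assumes a CUDA device and an initialized process group; the toy module is a placeholder.

# Sketch only: mirrors the ZeroInitContext usage in the tests below.
import torch
import torch.nn as nn

from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy

with ZeroInitContext(convert_fp16=True,
                     target_device=torch.cuda.current_device(),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     rm_torch_payload_on_the_fly=False):
    # parameters created here get col_attr attached and their payload sharded
    model = nn.Linear(1024, 1024)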

View File

@@ -1,6 +1,6 @@
import functools
from collections import OrderedDict
from typing import Any, Optional, Type
from typing import Any, Optional
import torch
import torch.distributed as dist
@@ -28,13 +28,14 @@ class ShardedModelV2(nn.Module):
def __init__(self,
module: nn.Module,
shard_strategy: Type[BaseShardStrategy],
shard_strategy: BaseShardStrategy,
process_group: Optional[ProcessGroup] = None,
reduce_scatter_process_group: Optional[ProcessGroup] = None,
reduce_scatter_bucket_size_mb: int = 25,
fp32_reduce_scatter: bool = False,
offload_config: Optional[dict] = None,
gradient_predivide_factor: Optional[float] = 1.0,
shard_param: bool = True,
use_memory_tracer: bool = False):
r"""
A demo to reconfigure zero1 shared_model.
@@ -43,23 +44,23 @@ class ShardedModelV2(nn.Module):
super().__init__()
self.logger = get_dist_logger()
# We force users to use ZeroInitContext
sharded = []
unsharded = []
for param in module.parameters():
assert hasattr(param, 'col_attr'), 'You must use ZeroInitContext to init your module first.'
sharded.append(param.col_attr.param_is_sharded)
unsharded.append(not param.col_attr.param_is_sharded)
assert all(sharded) or all(
unsharded), 'Parameters must be all sharded or all unsharded! Parameters are partially sharded now.'
self.shard_param = all(sharded)
self.module = module
self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
self.reduce_scatter_process_group = reduce_scatter_process_group or self.process_group
self.world_size = dist.get_world_size(self.process_group)
self.rank = dist.get_rank(self.process_group)
# Cast module to fp16 and cuda, in case user didn't use ZeroInitContext
self.module = module.half().cuda()
self.shard_strategy = shard_strategy
self.shard_param = shard_param
# In case user didn't use ZeroInitContext
for param in self.module.parameters():
if not hasattr(param, 'col_attr'):
param.col_attr = ShardedParamV2(param, process_group, rm_torch_payload=True)
if self.shard_param:
self.shard_strategy.shard([param.col_attr.data])
# Init Memory Statistics Collector
self._use_memory_tracer = use_memory_tracer
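
A hedged construction sketch for the restored signature: ShardedModelV2 no longer insists that the module come out of ZeroInitContext; it casts the module to fp16/CUDA itself and, when shard_param=True, attaches col_attr and shards parameters on the fly. Assumes CUDA and an initialized DATA process group; the toy module is a placeholder.

# Sketch only: shard_strategy is passed as an instance rather than a class.
import torch.nn as nn

from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

zero_model = ShardedModelV2(nn.Linear(1024, 1024),
                            shard_strategy=TensorShardStrategy(),
                            shard_param=True,
                            offload_config=None)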

View File

@@ -1,20 +1,23 @@
from enum import Enum
from typing import Dict, Optional
from typing import Dict, Optional, Type, Any
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
from colossalai.logging import get_dist_logger
from ._utils import has_inf_or_nan
@@ -27,7 +30,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
def __init__(self,
sharded_model: ShardedModelV2,
optimizer: Optimizer,
optimizer_class: Type[Optimizer],
cpu_offload: bool = False,
initial_scale: float = 2**32,
min_scale: float = 1,
@@ -37,15 +40,16 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
hysteresis: float = 2,
max_scale: int = 2**32,
dp_process_group: Optional[ProcessGroup] = None,
mp_process_group: Optional[ProcessGroup] = None) -> None:
mp_process_group: Optional[ProcessGroup] = None,
**defaults: Any) -> None:
"""
:param sharded_model: A sharded model initialized by class ShardedModelV2. The optimizer will use the
shard strategy provided by sharded model to shard param fp32 tensors.
:type sharded_model: sharded_model
:param optimizer_class: A class type of Optimizer
:type optimizer_class: Type[Optimizer]
:param cpu_offload: whether to offload the optimizer states to CPU.
:type cpu_offload: bool
@@ -80,8 +84,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
:type defaults: dict()
"""
assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
self._logger = get_dist_logger('ShardedOptimV2 logger')
self._optim_defaults = defaults
# initialize the M, V as zero tensors and initialize param fp32 from sharded_model.parameters()
super().__init__(optimizer)
self.optimizer = optimizer_class(sharded_model.parameters(), **self._optim_defaults)
super().__init__(self.optimizer)
self.shard_strategy = sharded_model.shard_strategy
self.model: ShardedModelV2 = sharded_model
if cpu_offload and not sharded_model.cpu_offload:
@@ -105,7 +114,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# Store fp32 param shards
self.master_params: Dict[Parameter, Tensor] = {}
for group in self.optim.param_groups:
for group in self.optimizer.param_groups:
for p in group['params']:
assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
is_param_sharded = p.col_attr.data.is_sharded
@@ -135,18 +144,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# assign master param pointers to p.data.
# We will not trigger data copy here.
for group in self.optim.param_groups:
for group in self.optimizer.param_groups:
for p in group['params']:
p.data = self.master_params[p]
# Now p.data is sharded
# So optimizer states are sharded naturally
ret = self.optim.step(*args, **kwargs)
ret = self.optimizer.step(*args, **kwargs)
# Copy master param data (fp32) to payload of col_attr (fp16)
# TODO() improve efficiency by gathering tensors into a chunk and transferring
# a chunk.
for group in self.optim.param_groups:
for group in self.optimizer.param_groups:
for p in group['params']:
is_param_sharded = p.col_attr.data.is_sharded
if not is_param_sharded:
@@ -190,7 +199,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self._found_overflow.fill_(0.0)
# check for overflow
for group in self.optim.param_groups:
for group in self.optimizer.param_groups:
for p in group['params']:
if has_inf_or_nan(p.grad):
self._found_overflow.fill_(1.0)
@@ -206,7 +215,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
def _unscale_grads(self):
assert self.optim_state == OptimState.SCALED
for group in self.optim.param_groups:
for group in self.optimizer.param_groups:
for p in group['params']:
if p.grad is not None:
p.grad.data.div_(self.loss_scale)
@@ -216,7 +225,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# We must set grad to None
# Because we will judge whether local grad accumulation
# is enabled by whether grad is None
self.optim.zero_grad(set_to_none=True)
self.optimizer.zero_grad(set_to_none=True)
def sync_grad(self):
pass
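
A hedged construction sketch for the reverted optimizer API: ShardedOptimizerV2 now receives an optimizer class rather than an instance, builds the inner optimizer from the sharded model's parameters itself, and forwards any extra keyword arguments (here lr) to that constructor.

# Sketch only: zero_model is a ShardedModelV2 instance, e.g. built as above.
import torch

from colossalai.zero.sharded_optim import ShardedOptimizerV2

sharded_optim = ShardedOptimizerV2(zero_model,
                                   torch.optim.Adam,
                                   cpu_offload=False,
                                   initial_scale=2**5,
                                   lr=1e-3)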

View File

@@ -2,10 +2,11 @@ from functools import partial
import torch
import torch.distributed as dist
from colossalai.logging import get_dist_logger
from colossalai.utils import checkpoint
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.nn.optimizer import CPUAdam
LOGGER = get_dist_logger('zero_test')
@@ -15,18 +16,20 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
fp32_reduce_scatter=False,
offload_config=None,
gradient_predivide_factor=1.0,
use_memory_tracer=False,
shard_strategy=TensorShardStrategy)
shard_param=True,
use_memory_tracer=False)
_ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
initial_scale=2**5,
min_scale=1,
growth_factor=2,
backoff_factor=0.5,
growth_interval=1000,
hysteresis=2,
max_scale=2**32,
lr=1e-3)
_ZERO_OPTIMIZER_CONFIG = dict(
optimizer_class=torch.optim.Adam, #CPUAdam
cpu_offload=False,
initial_scale=2**5,
min_scale=1,
growth_factor=2,
backoff_factor=0.5,
growth_interval=1000,
hysteresis=2,
max_scale=2**32,
lr=1e-3)
ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
zero=dict(

View File

@@ -1,13 +1,15 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import copy
from asyncio.log import logger
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -18,30 +20,36 @@ from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from common import CONFIG, check_grads_padding, run_fwd_bwd
from colossalai.testing import parameterize
@parameterize("enable_autocast", [True])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_model_test(enable_autocast, shard_strategy_class):
@parameterize("use_zero_init_ctx", [True])
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
def run_model_test(enable_autocast, use_zero_init_ctx, shard_strategy, logger):
test_models = ['repeated_computed_layers', 'resnet18', 'bert']
shard_strategy = shard_strategy_class()
shard_strategy = shard_strategy()
for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, _, _, criterion = get_components_func()
rm_torch_payload_on_the_fly = False
with ZeroInitContext(convert_fp16=True,
target_device=torch.cuda.current_device(),
shard_strategy=shard_strategy,
shard_param=True,
rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
zero_model = model_builder(checkpoint=True)
zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)
if use_zero_init_ctx:
with ZeroInitContext(convert_fp16=True,
target_device=torch.device(f'cpu:0'),
shard_strategy=shard_strategy,
shard_param=True,
rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
zero_model = model_builder(checkpoint=True)
zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)
model = model_builder(checkpoint=True).half()
col_model_deepcopy(zero_model, model)
model = model.cuda()
model = model_builder(checkpoint=True).half()
col_model_deepcopy(zero_model, model)
model = model.cuda()
else:
model = model_builder(checkpoint=True).half().cuda()
zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
model = DDP(model)
@@ -55,10 +63,15 @@ def run_model_test(enable_autocast, shard_strategy_class):
check_grads_padding(model, zero_model, loose=True)
# logger.debug('overall cuda ', zero_model._memstats_collector._overall_cuda)
# logger.debug('model cuda ', zero_model._memstats_collector._model_data_cuda)
def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_model_test()
logger = get_dist_logger()
logger.set_level('DEBUG')
run_model_test(logger=logger)
@pytest.mark.dist

View File

@@ -1,3 +1,4 @@
import copy
from functools import partial
import colossalai
@@ -5,18 +6,15 @@ import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.nn.optimizer import CPUAdam
from colossalai.testing import parameterize
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.optimizer import CPUAdam
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from colossalai.testing import parameterize
from common import CONFIG, check_sharded_params_padding
@@ -40,42 +38,36 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
@parameterize("cpu_offload", [True, False])
@parameterize("use_cpuadam", [True, False])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
def _run_test_sharded_optim_v2(cpu_offload, shard_strategy, use_cpuadam):
test_models = ['repeated_computed_layers', 'resnet18', 'bert']
shard_strategy = shard_strategy_class()
shard_strategy = shard_strategy()
if use_cpuadam and cpu_offload is False:
return
for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
with ZeroInitContext(convert_fp16=True,
target_device=torch.device(f'cpu:0'),
shard_strategy=shard_strategy,
shard_param=True,
rm_torch_payload_on_the_fly=False):
zero_model = model_builder(checkpoint=True)
zero_model = ShardedModelV2(zero_model,
model, train_dataloader, _, optimizer_class, criterion = get_components_func()
model = model(checkpoint=True).cuda()
zero_model = ShardedModelV2(copy.deepcopy(model),
shard_strategy,
offload_config=dict(device='cpu') if cpu_offload else None)
model = model_builder(checkpoint=True).half()
col_model_deepcopy(zero_model, model)
model = model.cuda().float()
if dist.get_world_size() > 1:
model = DDP(model)
lr = 1e-3
if use_cpuadam:
optimizer_class = CPUAdam
optim = optimizer_class(model.parameters(), lr=1e-3)
sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, cpu_offload=cpu_offload, initial_scale=2**5)
optim = torch.optim.Adam(model.parameters(), lr=lr)
sharded_optim = ShardedOptimizerV2(zero_model, CPUAdam, cpu_offload=cpu_offload, initial_scale=2**5, lr=lr)
else:
optim = optimizer_class(model.parameters(), lr=lr)
sharded_optim = ShardedOptimizerV2(zero_model,
optimizer_class,
cpu_offload=cpu_offload,
initial_scale=2**5,
lr=lr)
for i, (data, label) in enumerate(train_dataloader):
# FIXME() if i > 5, the unittest will fail
#FIXME() if i > 5, the unittest will fail
if i > 3:
break
data, label = data.cuda(), label.cuda()

View File

@@ -6,12 +6,12 @@ from functools import partial
import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from torchvision.models import resnet50
import torch.distributed as dist
def run_dist(rank, world_size, port):
@@ -64,10 +64,6 @@ def run_dist(rank, world_size, port):
'expected the output from different ranks to be the same, but got different values'
# FIXME: enable this test in next PR
@pytest.mark.skip
@pytest.mark.dist
def test_sharded_optim_with_sync_bn():
"""

View File

@@ -8,37 +8,24 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import parameterize
from common import CONFIG
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_zero_state_dict(shard_strategy_class):
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
def run_zero_state_dict(shard_strategy):
test_models = ['repeated_computed_layers', 'resnet18']
shard_strategy = shard_strategy_class()
shard_strategy = shard_strategy()
for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
with ZeroInitContext(convert_fp16=True,
target_device=torch.cuda.current_device(),
shard_strategy=shard_strategy,
shard_param=True,
rm_torch_payload_on_the_fly=False):
zero_model = model_builder(checkpoint=True)
zero_model = ShardedModelV2(zero_model, shard_strategy)
model = model_builder(checkpoint=True).half()
col_model_deepcopy(zero_model, model)
model = model.cuda()
model = model_builder()
model = model.half().cuda()
zero_model = ShardedModelV2(deepcopy(model), shard_strategy)
zero_state_dict = zero_model.state_dict()
for key, val in model.state_dict().items():
assert torch.equal(val, zero_state_dict[key])

View File

@@ -1,24 +1,21 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import copy
from functools import partial
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
import pytest
import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params,
check_sharded_params_padding)
from tests.components_to_test.registry import non_distributed_component_funcs
from common import check_sharded_params_padding, ZERO_PARALLEL_CONFIG, MP_PARALLEL_CONFIG, check_params
def run_dist(rank, world_size, port, parallel_config):
@@ -33,16 +30,10 @@ def run_dist(rank, world_size, port, parallel_config):
for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
with ZeroInitContext(convert_fp16=hasattr(gpc.config, 'fp16'),
target_device=torch.cuda.current_device(),
shard_strategy=gpc.config.zero.model_config.shared_strategy(
gpc.get_group(ParallelMode.DATA)),
shard_param=True):
colo_model = model_builder(checkpoint=True)
torch_model = model_builder(checkpoint=True).half()
col_model_deepcopy(colo_model, torch_model)
torch_model = torch_model.cuda().float()
colo_model = model_builder(checkpoint=True)
torch_model = copy.deepcopy(colo_model).cuda()
torch_model.train()
engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
optimizer=optimizer_class,
criterion=criterion,
@@ -91,10 +82,6 @@ def run_dist(rank, world_size, port, parallel_config):
check_sharded_params_padding(torch_model, colo_model, loose=True)
# FIXME: enable this test in next PR
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2, 4])
def test_mp_engine(world_size):
@@ -102,7 +89,6 @@ def test_mp_engine(world_size):
mp.spawn(run_func, nprocs=world_size)
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
def test_zero_engine(world_size):