Revert "[zero] update sharded optim and fix zero init ctx" (#456)

* Revert "polish code"

This reverts commit 8cf7ff08cf.

* Revert "rename variables"

This reverts commit e99af94ab8.

* Revert "remove surplus imports"

This reverts commit 46add4a5c5.

* Revert "update sharded optim and fix zero init ctx"

This reverts commit 57567ee768.
Jiarui Fang 2022-03-18 15:22:43 +08:00 committed by GitHub
parent 8cf7ff08cf
commit e2e9f82588
11 changed files with 161 additions and 161 deletions

View File

@@ -5,7 +5,7 @@ import argparse
 import os
 import pprint
 from pathlib import Path
-from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union, Type

 import torch
 import torch.nn as nn
@@ -21,13 +21,13 @@ from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.engine import Engine
-from colossalai.engine.ophooks import BaseOpHook
 from colossalai.global_variables import moe_env
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import (accumulate_gradient, get_current_device, is_using_ddp, is_using_pp, is_using_sequence,
                               sync_model_param)
 from colossalai.zero import convert_to_zero_v2
+from colossalai.engine.ophooks import BaseOpHook
 from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
@@ -217,8 +217,8 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
                       verbose=verbose)


-def initialize(model: nn.Module,
-               optimizer: Optimizer,
+def initialize(model: Union[Callable, nn.Module],
+               optimizer: Union[Type[Optimizer], Optimizer],
                criterion: Optional[_Loss] = None,
                train_dataloader: Optional[Iterable] = None,
                test_dataloader: Optional[Iterable] = None,
@@ -278,10 +278,12 @@ def initialize(model: nn.Module,
         cfg_ = {}
         optimizer_config = zero_cfg.get('optimizer_config', None)
         model_config = zero_cfg.get('model_config', None)
-        model, optimizer = convert_to_zero_v2(model, model_config=model_config, optimizer_config=optimizer_config)
+        model, optimizer = convert_to_zero_v2(model_builder=model,
+                                              model_config=model_config,
+                                              optimizer_config=optimizer_config)

         logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
-        # FIXME() throw a warning if using zero with MP
+        #FIXME() throw a warning if using zero with MP
         if gpc.get_world_size(ParallelMode.MODEL) > 1:
             logger.warning("ZeRO currently has not been tested with model parallelism.", ranks=[0])
     else:
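
For orientation, here is a minimal sketch (not part of this commit) of how the restored signature above is meant to be driven when ZeRO is enabled: model may be a builder callable rather than a built nn.Module, and optimizer may be an optimizer class rather than an instance. The builder and config path below are hypothetical.

import colossalai
import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # hypothetical toy model, used only to illustrate the Callable form
    return nn.Sequential(nn.Linear(128, 128), nn.Linear(128, 10))


colossalai.launch_from_torch(config='./zero_config.py')    # hypothetical config file with a `zero` section
engine, train_dataloader, _, _ = colossalai.initialize(model=build_model,           # Callable, built inside convert_to_zero_v2
                                                       optimizer=torch.optim.Adam,  # class; instantiated by ShardedOptimizerV2
                                                       criterion=nn.CrossEntropyLoss())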

View File

@@ -1,17 +1,22 @@
-from typing import Tuple
+from typing import Callable

+import torch
 import torch.nn as nn
-from colossalai.amp.naive_amp import NaiveAMPModel
-from colossalai.logging import get_dist_logger
+from torch.optim import Optimizer
 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
-from torch.optim import Optimizer
+from colossalai.zero.shard_utils import TensorShardStrategy
+from colossalai.amp.naive_amp import NaiveAMPModel
+from colossalai.core import global_context as gpc
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.logging import get_dist_logger

 from .sharded_model import ShardedModel
 from .sharded_optim import ShardedOptimizer


-def convert_to_zero_v2(model: nn.Module, model_config, optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]:
+def convert_to_zero_v2(model_builder: Callable, model_config, optimizer_config) -> (ShardedModelV2, ShardedOptimizerV2):
     """
     A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
@@ -26,6 +31,9 @@ def convert_to_zero_v2(model: nn.Module, model_config, optimizer_config) -> Tupl
     logger = get_dist_logger('convert_to_zero_v2')

+    # FIXME() pass shard strategy from config
+    shard_strategy = TensorShardStrategy()
+
     logger.info(f'optimizer_config is {optimizer_config}')
     if optimizer_config is None:
         optimizer_config = dict()
@@ -33,7 +41,18 @@ def convert_to_zero_v2(model: nn.Module, model_config, optimizer_config) -> Tupl
     if model_config is None:
         model_config = dict()

-    zero_model = ShardedModelV2(model, **model_config)
+    if isinstance(model_builder, nn.Module):
+        model = model_builder
+    elif isinstance(model_builder, Callable):
+        with ZeroInitContext(convert_fp16='fp16' in gpc.config,
+                             target_device=torch.cuda.current_device(),
+                             shard_strategy=shard_strategy,
+                             shard_param=model_config.get('shard_param', True)):
+            model = model_builder()
+    else:
+        raise TypeError(f"convert_to_zero_v2 dose not support model_builder of type {type(convert_to_zero_v2)}")
+
+    zero_model = ShardedModelV2(model, shard_strategy=shard_strategy, **model_config)
     zero_optimizer = ShardedOptimizerV2(zero_model, **optimizer_config)
     return zero_model, zero_optimizer
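
As a usage note, a hedged sketch of how the restored convert_to_zero_v2 is driven: pass a builder callable and the model is constructed inside ZeroInitContext so parameters are sharded at creation time; pass an nn.Module and it is wrapped as-is. The builder, shapes, and hyperparameters are illustrative, and gpc.config is assumed to have been populated by colossalai.launch.

import torch
import torch.nn as nn
from colossalai.zero import convert_to_zero_v2


def build_model() -> nn.Module:
    return nn.Linear(64, 64)    # hypothetical module


# Builder form: parameters are created (and sharded) under ZeroInitContext.
zero_model, zero_optim = convert_to_zero_v2(model_builder=build_model,
                                            model_config=dict(shard_param=True),
                                            optimizer_config=dict(optimizer_class=torch.optim.Adam, lr=1e-3))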

View File

@@ -4,7 +4,6 @@ import torch
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.shard_utils import BaseShardStrategy
-from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2

 # Inserts _post_init_method at the end of init method
@@ -159,10 +158,3 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             # if param.col_attr.grad and self.shard_grad:
             #     self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
             #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
-
-            # We must cast buffers
-            # If we use BN, buffers may be on CPU and Float
-            # We must cast them
-            for buffer in module.buffers():
-                buffer.data = buffer.data.to(device=torch.cuda.current_device())
-                if self.convert_fp16:
-                    buffer.data = cast_tensor_to_fp16(buffer.data)
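
For reference, a hedged sketch of entering the context manager whose _post_init_method is trimmed above; the keyword arguments mirror the ones used elsewhere in this diff, and the module built inside is hypothetical.

import torch
import torch.nn as nn
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy

with ZeroInitContext(convert_fp16=True,
                     target_device=torch.cuda.current_device(),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     rm_torch_payload_on_the_fly=False):
    # parameters created in this scope get a `col_attr` and are sharded on creation
    model = nn.Linear(64, 64)    # hypothetical module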

View File

@@ -1,6 +1,6 @@
 import functools
 from collections import OrderedDict
-from typing import Any, Optional, Type
+from typing import Any, Optional

 import torch
 import torch.distributed as dist
@@ -28,13 +28,14 @@ class ShardedModelV2(nn.Module):

     def __init__(self,
                  module: nn.Module,
-                 shard_strategy: Type[BaseShardStrategy],
+                 shard_strategy: BaseShardStrategy,
                  process_group: Optional[ProcessGroup] = None,
                  reduce_scatter_process_group: Optional[ProcessGroup] = None,
                  reduce_scatter_bucket_size_mb: int = 25,
                  fp32_reduce_scatter: bool = False,
                  offload_config: Optional[dict] = None,
                  gradient_predivide_factor: Optional[float] = 1.0,
+                 shard_param: bool = True,
                  use_memory_tracer: bool = False):
         r"""
         A demo to reconfigure zero1 shared_model.
@@ -43,23 +44,23 @@ class ShardedModelV2(nn.Module):
         super().__init__()
         self.logger = get_dist_logger()

-        # We force users to use ZeroInitContext
-        sharded = []
-        unsharded = []
-        for param in module.parameters():
-            assert hasattr(param, 'col_attr'), 'You must use ZeroInitContext to init your module first.'
-            sharded.append(param.col_attr.param_is_sharded)
-            unsharded.append(not param.col_attr.param_is_sharded)
-        assert all(sharded) or all(
-            unsharded), 'Parameters must be all sharded or all unsharded! Parameters are partially sharded nwo.'
-        self.shard_param = all(sharded)
-        self.module = module
         self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
         self.reduce_scatter_process_group = reduce_scatter_process_group or self.process_group
         self.world_size = dist.get_world_size(self.process_group)
         self.rank = dist.get_rank(self.process_group)

+        # Cast module to fp16 and cuda, in case user didn't use ZeroInitContext
+        self.module = module.half().cuda()
+
         self.shard_strategy = shard_strategy
+        self.shard_param = shard_param
+
+        # In case user didn't use ZeroInitContext
+        for param in self.module.parameters():
+            if not hasattr(param, 'col_attr'):
+                param.col_attr = ShardedParamV2(param, process_group, rm_torch_payload=True)
+                if self.shard_param:
+                    self.shard_strategy.shard([param.col_attr.data])
+
         # Init Memory Statistics Collector
         self._use_memory_tracer = use_memory_tracer
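
A hedged sketch of the restored constructor in isolation: the shard strategy is passed as an instance, and shard_param tells the wrapper to attach col_attr and shard parameters itself when the module was not built under ZeroInitContext. The module is hypothetical, and a data-parallel group set up by colossalai.launch is assumed.

import torch.nn as nn
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

module = nn.Linear(64, 64)    # hypothetical module; the wrapper casts it to fp16 and moves it to CUDA
zero_model = ShardedModelV2(module,
                            shard_strategy=TensorShardStrategy(),    # an instance, not a class
                            shard_param=True)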

View File

@@ -1,20 +1,23 @@
 from enum import Enum
-from typing import Dict, Optional
+from typing import Dict, Optional, Type, Any

 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.zero.sharded_model import ShardedModelV2
-from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer

+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.zero.shard_utils import BaseShardStrategy
+from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
+from colossalai.logging import get_dist_logger
+
 from ._utils import has_inf_or_nan
@@ -27,7 +30,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):

     def __init__(self,
                  sharded_model: ShardedModelV2,
-                 optimizer: Optimizer,
+                 optimizer_class: Type[Optimizer],
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
                  min_scale: float = 1,
@@ -37,15 +40,16 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  hysteresis: float = 2,
                  max_scale: int = 2**32,
                  dp_process_group: Optional[ProcessGroup] = None,
-                 mp_process_group: Optional[ProcessGroup] = None) -> None:
+                 mp_process_group: Optional[ProcessGroup] = None,
+                 **defaults: Any) -> None:
        """
        :param sharded_model: A sharded model initialized by class ShardedModelV2. The optimizer will use the
            shard strategy provided by sharded model to shard param fp32 tensors.
        :type sharded_model: sharded_model

        :param optimizer_class: A class type of Optimizer
        :type optimizer_class: Type[Optimizer]

        :param cpu_offload: is offloading the optimizer states to CPU.
        :type cpu_offload: bool
@@ -80,8 +84,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        :type defaults: dict()
        """
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'

-        super().__init__(optimizer)
+        self._logger = get_dist_logger('ShardedOptimV2 logger')
+        self._optim_defaults = defaults
+        # initialize the M, V as zeros tensors and initialize param fp32 from sharded_model.parameters()
+        self.optimizer = optimizer_class(sharded_model.parameters(), **self._optim_defaults)
+
+        super().__init__(self.optimizer)
         self.shard_strategy = sharded_model.shard_strategy
         self.model: ShardedModelV2 = sharded_model
         if cpu_offload and not sharded_model.cpu_offload:
@@ -105,7 +114,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Store fp32 param shards
         self.master_params: Dict[Parameter, Tensor] = {}
-        for group in self.optim.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
                 is_param_sharded = p.col_attr.data.is_sharded
@@ -135,18 +144,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # assign master param pointers to p.data.
         # We will not trigger data copy here.
-        for group in self.optim.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 p.data = self.master_params[p]
                 # Now p.data is sharded
                 # So optimizer states are sharded naturally

-        ret = self.optim.step(*args, **kwargs)
+        ret = self.optimizer.step(*args, **kwargs)

         # Copy master param data (fp32) to payload of col_attr (fp16)
         # TODO() improve efficiency by gathering tensors into a chunk and transfering
         # a chunk.
-        for group in self.optim.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 is_param_sharded = p.col_attr.data.is_sharded
                 if not is_param_sharded:
@@ -190,7 +199,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._found_overflow.fill_(0.0)

         # check for overflow
-        for group in self.optim.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 if has_inf_or_nan(p.grad):
                     self._found_overflow.fill_(1.0)
@@ -206,7 +215,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def _unscale_grads(self):
         assert self.optim_state == OptimState.SCALED
-        for group in self.optim.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 if p.grad is not None:
                     p.grad.data.div_(self.loss_scale)
@@ -216,7 +225,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # We must set grad to None
         # Because we will judge whether local grad accumulation
         # is enabled by wheter grad is None
-        self.optim.zero_grad(set_to_none=True)
+        self.optimizer.zero_grad(set_to_none=True)

     def sync_grad(self):
         pass
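
A hedged sketch of the restored call pattern: the optimizer is handed over as a class and its hyperparameters ride along as **defaults, so ShardedOptimizerV2 builds the fp32 master copies from sharded_model.parameters() itself. The wrapped module and the values are illustrative.

import torch
import torch.nn as nn
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2

zero_model = ShardedModelV2(nn.Linear(64, 64), TensorShardStrategy())    # hypothetical module
sharded_optim = ShardedOptimizerV2(zero_model,
                                   torch.optim.Adam,     # optimizer_class, instantiated internally
                                   cpu_offload=False,
                                   initial_scale=2**5,   # illustrative loss-scale setting
                                   lr=1e-3)              # forwarded to Adam via **defaults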

View File

@@ -2,10 +2,11 @@ from functools import partial

 import torch
 import torch.distributed as dist
 from colossalai.logging import get_dist_logger
 from colossalai.utils import checkpoint
-from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.nn.optimizer import CPUAdam

 LOGGER = get_dist_logger('zero_test')
@@ -15,18 +16,20 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                           fp32_reduce_scatter=False,
                           offload_config=None,
                           gradient_predivide_factor=1.0,
-                          use_memory_tracer=False,
-                          shard_strategy=TensorShardStrategy)
+                          shard_param=True,
+                          use_memory_tracer=False)

-_ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
-                              initial_scale=2**5,
-                              min_scale=1,
-                              growth_factor=2,
-                              backoff_factor=0.5,
-                              growth_interval=1000,
-                              hysteresis=2,
-                              max_scale=2**32,
-                              lr=1e-3)
+_ZERO_OPTIMIZER_CONFIG = dict(
+    optimizer_class=torch.optim.Adam,    #CPUAdam
+    cpu_offload=False,
+    initial_scale=2**5,
+    min_scale=1,
+    growth_factor=2,
+    backoff_factor=0.5,
+    growth_interval=1000,
+    hysteresis=2,
+    max_scale=2**32,
+    lr=1e-3)

 ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
                             zero=dict(
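
For context, a hedged sketch of how the two dicts above are expected to slot into the ZeRO parallel config: initialize() earlier in this diff looks up model_config and optimizer_config inside the zero section, so the truncated ZERO_PARALLEL_CONFIG presumably continues along these lines.

ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
                            zero=dict(model_config=_ZERO_MODEL_CONFIG,
                                      optimizer_config=_ZERO_OPTIMIZER_CONFIG))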

View File

@@ -1,13 +1,15 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+import copy
+from asyncio.log import logger
 from functools import partial

 import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize
+from colossalai.logging import get_dist_logger
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
@@ -18,30 +20,36 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP

 from common import CONFIG, check_grads_padding, run_fwd_bwd
+from colossalai.testing import parameterize


 @parameterize("enable_autocast", [True])
-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def run_model_test(enable_autocast, shard_strategy_class):
+@parameterize("use_zero_init_ctx", [True])
+@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def run_model_test(enable_autocast, use_zero_init_ctx, shard_strategy, logger):
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
-    shard_strategy = shard_strategy_class()
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, _, criterion = get_components_func()
         rm_torch_payload_on_the_fly = False

-        with ZeroInitContext(convert_fp16=True,
-                             target_device=torch.cuda.current_device(),
-                             shard_strategy=shard_strategy,
-                             shard_param=True,
-                             rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
-            zero_model = model_builder(checkpoint=True)
-        zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)
-
-        model = model_builder(checkpoint=True).half()
-        col_model_deepcopy(zero_model, model)
-        model = model.cuda()
+        if use_zero_init_ctx:
+            with ZeroInitContext(convert_fp16=True,
+                                 target_device=torch.device(f'cpu:0'),
+                                 shard_strategy=shard_strategy,
+                                 shard_param=True,
+                                 rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly):
+                zero_model = model_builder(checkpoint=True)
+            zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)
+
+            model = model_builder(checkpoint=True).half()
+            col_model_deepcopy(zero_model, model)
+            model = model.cuda()
+        else:
+            model = model_builder(checkpoint=True).half().cuda()
+            zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)

         model = DDP(model)
@@ -55,10 +63,15 @@ def run_model_test(enable_autocast, shard_strategy_class):
         check_grads_padding(model, zero_model, loose=True)

+        # logger.debug('overall cuda ', zero_model._memstats_collector._overall_cuda)
+        # logger.debug('model cuda ', zero_model._memstats_collector._model_data_cuda)


 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_model_test()
+    logger = get_dist_logger()
+    logger.set_level('DEBUG')
+    run_model_test(logger=logger)


 @pytest.mark.dist

View File

@@ -1,3 +1,4 @@
+import copy
 from functools import partial

 import colossalai
@@ -5,18 +6,15 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from colossalai.nn.optimizer import CPUAdam
-from colossalai.testing import parameterize
 from colossalai.utils import free_port
-from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
-from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
-from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
+from colossalai.nn.optimizer import CPUAdam
+from colossalai.zero.sharded_optim._utils import has_inf_or_nan
+from colossalai.testing import parameterize

 from common import CONFIG, check_sharded_params_padding
@@ -40,42 +38,36 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
 @parameterize("cpu_offload", [True, False])
 @parameterize("use_cpuadam", [True, False])
-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
+@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def _run_test_sharded_optim_v2(cpu_offload, shard_strategy, use_cpuadam):
     test_models = ['repeated_computed_layers', 'resnet18', 'bert']
-    shard_strategy = shard_strategy_class()
+    shard_strategy = shard_strategy()
     if use_cpuadam and cpu_offload is False:
         return
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
-
-        with ZeroInitContext(convert_fp16=True,
-                             target_device=torch.device(f'cpu:0'),
-                             shard_strategy=shard_strategy,
-                             shard_param=True,
-                             rm_torch_payload_on_the_fly=False):
-            zero_model = model_builder(checkpoint=True)
-        zero_model = ShardedModelV2(zero_model,
-                                    shard_strategy,
-                                    offload_config=dict(device='cpu') if cpu_offload else None)
-
-        model = model_builder(checkpoint=True).half()
-        col_model_deepcopy(zero_model, model)
-        model = model.cuda().float()
+        model, train_dataloader, _, optimizer_class, criterion = get_components_func()
+        model = model(checkpoint=True).cuda()
+        zero_model = ShardedModelV2(copy.deepcopy(model),
+                                    shard_strategy,
+                                    offload_config=dict(device='cpu') if cpu_offload else None)
         if dist.get_world_size() > 1:
             model = DDP(model)

+        lr = 1e-3
         if use_cpuadam:
-            optimizer_class = CPUAdam
-        optim = optimizer_class(model.parameters(), lr=1e-3)
-        sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
-        sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, cpu_offload=cpu_offload, initial_scale=2**5)
+            optim = torch.optim.Adam(model.parameters(), lr=lr)
+            sharded_optim = ShardedOptimizerV2(zero_model, CPUAdam, cpu_offload=cpu_offload, initial_scale=2**5, lr=lr)
+        else:
+            optim = optimizer_class(model.parameters(), lr=lr)
+            sharded_optim = ShardedOptimizerV2(zero_model,
+                                               optimizer_class,
+                                               cpu_offload=cpu_offload,
+                                               initial_scale=2**5,
+                                               lr=lr)

         for i, (data, label) in enumerate(train_dataloader):
-            # FIXME() if i > 5, the unittest will fail
+            #FIXME() if i > 5, the unittest will fail
             if i > 3:
                 break
             data, label = data.cuda(), label.cuda()
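
To make the comparison easier to follow, a hedged sketch of a single step on both sides, loosely following the _run_step helper named in the hunk header (not shown in this diff): the DDP baseline steps its plain optimizer, while the ZeRO pair steps ShardedOptimizerV2 on the sharded fp32 masters. The backward call on ShardedModelV2 is an assumption about its API, and dtype handling is elided.

def run_step_pair(model, optim, zero_model, sharded_optim, data, label, criterion):
    # baseline: plain (DDP) model with an ordinary optimizer
    optim.zero_grad()
    loss = criterion(model(data), label)
    loss.backward()
    optim.step()

    # ZeRO side: gradients are reduce-scattered by ShardedModelV2,
    # then ShardedOptimizerV2 unscales and steps on the sharded masters
    sharded_optim.zero_grad()
    zero_loss = criterion(zero_model(data), label)
    zero_model.backward(zero_loss)    # assumption: ShardedModelV2 exposes backward(loss); otherwise zero_loss.backward()
    sharded_optim.step()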

View File

@@ -6,12 +6,12 @@ from functools import partial
 import colossalai
 import pytest
 import torch
-import torch.distributed as dist
 import torch.multiprocessing as mp
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
+from colossalai.core import global_context as gpc
+from colossalai.context.parallel_mode import ParallelMode
 from torchvision.models import resnet50
+import torch.distributed as dist


 def run_dist(rank, world_size, port):
@@ -64,10 +64,6 @@ def run_dist(rank, world_size, port):
         'expected the output from different ranks to be the same, but got different values'


-# FIXME: enable this test in next PR
-@pytest.mark.skip
 @pytest.mark.dist
 def test_sharded_optim_with_sync_bn():
     """

View File

@@ -8,37 +8,24 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import parameterize
 from colossalai.utils import free_port
-from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
-from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from tests.components_to_test.registry import non_distributed_component_funcs
+from colossalai.testing import parameterize

 from common import CONFIG


-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def run_zero_state_dict(shard_strategy_class):
+@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
+def run_zero_state_dict(shard_strategy):
     test_models = ['repeated_computed_layers', 'resnet18']
-    shard_strategy = shard_strategy_class()
+    shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
-
-        with ZeroInitContext(convert_fp16=True,
-                             target_device=torch.cuda.current_device(),
-                             shard_strategy=shard_strategy,
-                             shard_param=True,
-                             rm_torch_payload_on_the_fly=False):
-            zero_model = model_builder(checkpoint=True)
-        zero_model = ShardedModelV2(zero_model, shard_strategy)
-
-        model = model_builder(checkpoint=True).half()
-        col_model_deepcopy(zero_model, model)
-        model = model.cuda()
+        model = model_builder()
+        model = model.half().cuda()
+        zero_model = ShardedModelV2(deepcopy(model), shard_strategy)

         zero_state_dict = zero_model.state_dict()
         for key, val in model.state_dict().items():
             assert torch.equal(val, zero_state_dict[key])

View File

@@ -1,24 +1,21 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+import copy
 from functools import partial

+from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
+import pytest
 import colossalai
-import pytest
-import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.utils import free_port
-from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
-from tests.components_to_test.registry import non_distributed_component_funcs
+import torch.multiprocessing as mp
+import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
-from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params,
-                    check_sharded_params_padding)
+from tests.components_to_test.registry import non_distributed_component_funcs
+from common import check_sharded_params_padding, ZERO_PARALLEL_CONFIG, MP_PARALLEL_CONFIG, check_params


 def run_dist(rank, world_size, port, parallel_config):
@@ -33,16 +30,10 @@ def run_dist(rank, world_size, port, parallel_config):
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
         model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
-        with ZeroInitContext(convert_fp16=hasattr(gpc.config, 'fp16'),
-                             target_device=torch.cuda.current_device(),
-                             shard_strategy=gpc.config.zero.model_config.shared_strategy(
-                                 gpc.get_group(ParallelMode.DATA)),
-                             shard_param=True):
-            colo_model = model_builder(checkpoint=True)

-        torch_model = model_builder(checkpoint=True).half()
-        col_model_deepcopy(colo_model, torch_model)
-        torch_model = torch_model.cuda().float()
+        colo_model = model_builder(checkpoint=True)
+        torch_model = copy.deepcopy(colo_model).cuda()
+        torch_model.train()
         engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
                                                                optimizer=optimizer_class,
                                                                criterion=criterion,
@@ -91,10 +82,6 @@ def run_dist(rank, world_size, port, parallel_config):
             check_sharded_params_padding(torch_model, colo_model, loose=True)


-# FIXME: enable this test in next PR
-@pytest.mark.skip
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [2, 4])
 def test_mp_engine(world_size):
@@ -102,7 +89,6 @@ def test_mp_engine(world_size):
     mp.spawn(run_func, nprocs=world_size)


-@pytest.mark.skip
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
 def test_zero_engine(world_size):