From a241f61b343d0332fca2022af088a85ab0bb5974 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 18 Mar 2022 16:18:31 +0800 Subject: [PATCH] [zero] Update initialize for ZeRO (#458) * polish code * shard strategy receive pg in shard() / gather() * update zero engine * polish code --- colossalai/engine/ophooks/zero_hook.py | 15 ++++++--- colossalai/initialize.py | 5 ++- colossalai/zero/__init__.py | 6 ++-- .../zero/shard_utils/base_shard_strategy.py | 17 ++++------ .../bucket_tensor_shard_strategy.py | 15 ++++----- .../zero/shard_utils/tensor_shard_strategy.py | 31 ++++++++----------- .../zero/sharded_model/sharded_model_v2.py | 9 +++--- .../zero/sharded_optim/sharded_optim_v2.py | 10 +++--- tests/test_zero_data_parallel/common.py | 5 ++- .../test_init_context.py | 11 ++++--- .../test_shard_param.py | 8 ++--- .../test_sharded_optim_with_sync_bn.py | 14 +++++---- .../test_zero_engine.py | 18 +++++------ 13 files changed, 84 insertions(+), 80 deletions(-) diff --git a/colossalai/engine/ophooks/zero_hook.py b/colossalai/engine/ophooks/zero_hook.py index a4df0f502..69390c512 100644 --- a/colossalai/engine/ophooks/zero_hook.py +++ b/colossalai/engine/ophooks/zero_hook.py @@ -1,6 +1,7 @@ from typing import Optional import torch +import torch.distributed as dist from colossalai.registry import OPHOOKS from colossalai.utils import get_current_device from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector @@ -17,9 +18,13 @@ class ZeroHook(BaseOpHook): A hook to process sharded param for ZeRO method. """ - def __init__(self, shard_strategy: BaseShardStrategy, memstarts_collector: Optional[MemStatsCollector]): + def __init__(self, + shard_strategy: BaseShardStrategy, + memstarts_collector: Optional[MemStatsCollector], + process_group: Optional[dist.ProcessGroup] = None): super().__init__() self.shard_strategy = shard_strategy + self.process_group = process_group # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU self.computing_device = torch.device(f'cuda:{get_current_device()}') @@ -30,7 +35,7 @@ class ZeroHook(BaseOpHook): for param in module.parameters(): assert hasattr(param, 'col_attr') tensor_list.append(param.col_attr.data) - self.shard_strategy.gather(tensor_list) + self.shard_strategy.gather(tensor_list, self.process_group) for param in module.parameters(): if param.col_attr.data.device != self.computing_device: param.col_attr.data.to(self.computing_device) @@ -45,7 +50,7 @@ class ZeroHook(BaseOpHook): for param in module.parameters(): assert hasattr(param, 'col_attr') tensor_list.append(param.col_attr.data) - self.shard_strategy.shard(tensor_list) + self.shard_strategy.shard(tensor_list, self.process_group) for param in module.parameters(): param.col_attr.remove_torch_payload() @@ -54,7 +59,7 @@ class ZeroHook(BaseOpHook): for param in module.parameters(): assert hasattr(param, 'col_attr') tensor_list.append(param.col_attr.data) - self.shard_strategy.gather(tensor_list) + self.shard_strategy.gather(tensor_list, self.process_group) for param in module.parameters(): if param.col_attr.data.device != self.computing_device: param.col_attr.data.to(self.computing_device) @@ -80,7 +85,7 @@ class ZeroHook(BaseOpHook): for param in module.parameters(): assert hasattr(param, 'col_attr') tensor_list.append(param.col_attr.data) - self.shard_strategy.shard(tensor_list) + self.shard_strategy.shard(tensor_list, self.process_group) for param in module.parameters(): param.col_attr.remove_torch_payload() diff --git a/colossalai/initialize.py b/colossalai/initialize.py index 9870eda8c..7a5d05d60 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -278,7 +278,10 @@ def initialize(model: nn.Module, cfg_ = {} optimizer_config = zero_cfg.get('optimizer_config', None) model_config = zero_cfg.get('model_config', None) - model, optimizer = convert_to_zero_v2(model, model_config=model_config, optimizer_config=optimizer_config) + model, optimizer = convert_to_zero_v2(model, + optimizer, + model_config=model_config, + optimizer_config=optimizer_config) logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0]) # FIXME() throw a warning if using zero with MP diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py index ecb573669..b94bb370c 100644 --- a/colossalai/zero/__init__.py +++ b/colossalai/zero/__init__.py @@ -1,5 +1,6 @@ from typing import Tuple +import torch import torch.nn as nn from colossalai.amp.naive_amp import NaiveAMPModel from colossalai.logging import get_dist_logger @@ -11,7 +12,8 @@ from .sharded_model import ShardedModel from .sharded_optim import ShardedOptimizer -def convert_to_zero_v2(model: nn.Module, model_config, optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]: +def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config, + optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]: """ A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading @@ -34,7 +36,7 @@ def convert_to_zero_v2(model: nn.Module, model_config, optimizer_config) -> Tupl model_config = dict() zero_model = ShardedModelV2(model, **model_config) - zero_optimizer = ShardedOptimizerV2(zero_model, **optimizer_config) + zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config) return zero_model, zero_optimizer diff --git a/colossalai/zero/shard_utils/base_shard_strategy.py b/colossalai/zero/shard_utils/base_shard_strategy.py index ddae476cc..7c2f4c9f6 100644 --- a/colossalai/zero/shard_utils/base_shard_strategy.py +++ b/colossalai/zero/shard_utils/base_shard_strategy.py @@ -1,26 +1,21 @@ from abc import ABC, abstractmethod -from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor -import torch.distributed as dist from typing import List, Optional +import torch.distributed as dist +from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor + class BaseShardStrategy(ABC): - def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None: + def __init__(self) -> None: """Abstract Shard Strategy. Use to shard a tensors on multiple GPUs. - - Args: - process_group (Optional[dist.ProcessGroup], optional): the process group. Defaults to None. """ - self.process_group = process_group - self.world_size = dist.get_world_size(self.process_group) - self.local_rank = dist.get_rank(self.process_group) super().__init__() @abstractmethod - def shard(self, tensor_list: List[ShardedTensor]): + def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None): pass @abstractmethod - def gather(self, tensor_list: List[ShardedTensor]): + def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None): pass diff --git a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py index c118f7710..7f81a0567 100644 --- a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py +++ b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py @@ -1,18 +1,17 @@ -from typing import List +from typing import List, Optional import torch import torch.distributed as dist -from torch._utils import _flatten_dense_tensors as flatten - from colossalai.utils import get_current_device from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor +from torch._utils import _flatten_dense_tensors as flatten from .tensor_shard_strategy import TensorShardStrategy class BucketTensorShardStrategy(TensorShardStrategy): - def gather(self, tensor_list: List[ShardedTensor]): + def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None): tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded] if len(tensor_list) == 0: return @@ -21,15 +20,17 @@ class BucketTensorShardStrategy(TensorShardStrategy): buffer_list: List[torch.Tensor] = [] tensor_numels = [t.payload.numel() for t in tensor_list] buffer_size = sum(tensor_numels) - for i in range(self.world_size): - if i == self.local_rank: + world_size = dist.get_world_size(process_group) + rank = dist.get_rank(process_group) + for i in range(world_size): + if i == rank: buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device())) # Release payload here, to decrease peak memory usage for t in tensor_list: t.reset_payload(None) else: buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device())) - dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group) + dist.all_gather(buffer_list, buffer_list[rank], group=process_group) # Move to target device before splitting buffer # Ensure we utilize maximum PCIE bandwidth buffer_list = [buffer.to(target_device) for buffer in buffer_list] diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/shard_utils/tensor_shard_strategy.py index 4e4bdaabb..b393d4e88 100644 --- a/colossalai/zero/shard_utils/tensor_shard_strategy.py +++ b/colossalai/zero/shard_utils/tensor_shard_strategy.py @@ -2,49 +2,44 @@ from typing import List, Optional import torch import torch.distributed as dist - +from colossalai.utils import get_current_device from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.sharded_model._zero3_utils import get_shard from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor -from colossalai.utils import get_current_device class TensorShardStrategy(BaseShardStrategy): - def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None: - super().__init__(process_group) - - def shard(self, tensor_list: List[ShardedTensor]): + def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None): for t in tensor_list: - self._shard_tensor(t) + self._shard_tensor(t, process_group) - def gather(self, tensor_list: List[ShardedTensor]): + def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None): for t in tensor_list: - self._gather_tensor(t) + self._gather_tensor(t, process_group) - def _shard_tensor(self, t: ShardedTensor): + def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None): if t.is_sharded: return - sharded_payload, _ = get_shard(t.payload, self.local_rank, self.world_size) + sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group)) t.reset_payload(sharded_payload) t.is_sharded = True - def _gather_tensor(self, t: ShardedTensor): + def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None): if not t.is_sharded: return target_device = t.device buffer_list = [] payload_numel = t.payload.numel() - for i in range(self.world_size): - if i == self.local_rank: + world_size = dist.get_world_size(process_group) + rank = dist.get_rank(process_group) + for i in range(world_size): + if i == rank: buffer_list.append(t.payload.cuda(get_current_device())) else: buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device())) - torch.distributed.all_gather(buffer_list, - buffer_list[self.local_rank], - group=self.process_group, - async_op=False) + dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False) gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape) t.reset_payload(gathered_payload) t.to(target_device) diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index a6d7af7ec..f1cda2148 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -70,7 +70,8 @@ class ShardedModelV2(nn.Module): self._iter_cnter = 0 # Register hooks - register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy, self._memstats_collector)]) + register_ophooks_recursively(self.module, + [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)]) self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters())) self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook) @@ -145,7 +146,7 @@ class ShardedModelV2(nn.Module): if self.shard_param: for p in self.module.parameters(): if not p.col_attr.param_is_sharded: - self.shard_strategy.shard([p.col_attr.data]) + self.shard_strategy.shard([p.col_attr.data], self.process_group) for p in self.module.parameters(): p.col_attr.bwd_count = 0 if not p.requires_grad: @@ -229,13 +230,13 @@ class ShardedModelV2(nn.Module): param.col_attr.fp16_grad = reduced_grad.data def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]': - self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()]) + self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()], self.process_group) prev_params = {} for p in self.module.parameters(): prev_params[p] = p.data p.data = p.col_attr.data.payload gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars) - self.shard_strategy.shard([p.col_attr.data for p in self.module.parameters()]) + self.shard_strategy.shard([p.col_attr.data for p in self.module.parameters()], self.process_group) for p in self.module.parameters(): p.data = prev_params[p] return gathered_state_dict diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py index 3ee8b09c2..a4d260ed8 100644 --- a/colossalai/zero/sharded_optim/sharded_optim_v2.py +++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py @@ -7,6 +7,7 @@ import torch.nn as nn from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32 @@ -101,6 +102,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): hysteresis=hysteresis, max_scale=max_scale) self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device()) + self._logger = get_dist_logger() # Store fp32 param shards self.master_params: Dict[Parameter, Tensor] = {} @@ -113,12 +115,12 @@ class ShardedOptimizerV2(ColossalaiOptimizer): # TODO (ver217): we may not use shard / gather here # Param is no sharded, which means we use ZeRO-2 here # As we only store param shard, we shard it here - self.shard_strategy.shard([p.col_attr.data]) + self.shard_strategy.shard([p.col_attr.data], self.dp_process_group) self.master_params[p] = cast_tensor_to_fp32(p.col_attr.data.payload).to(self.device) if not is_param_sharded: # In this branch, there's no need to shard param # So we gather here - self.shard_strategy.gather([p.col_attr.data]) + self.shard_strategy.gather([p.col_attr.data], self.dp_process_group) def step(self, *args, **kwargs): # unscale grads if scaled @@ -155,7 +157,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): # But we only have updated fp32 param shard here # So we first shard full fp16 param and copy fp32 param shard to it # Then we will gather them - self.shard_strategy.shard([p.col_attr.data]) + self.shard_strategy.shard([p.col_attr.data], self.dp_process_group) # We have to use `copy_payload` instead of `reset_payload` # Since p.data is fp32 and p.col_attr.data is fp16 @@ -164,7 +166,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): if not is_param_sharded: # We gather full fp16 param here - self.shard_strategy.gather([p.col_attr.data]) + self.shard_strategy.gather([p.col_attr.data], self.dp_process_group) p.data = p.col_attr.data.payload return ret diff --git a/tests/test_zero_data_parallel/common.py b/tests/test_zero_data_parallel/common.py index 6721dc8b8..fc95d59b4 100644 --- a/tests/test_zero_data_parallel/common.py +++ b/tests/test_zero_data_parallel/common.py @@ -16,7 +16,7 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25, offload_config=None, gradient_predivide_factor=1.0, use_memory_tracer=False, - shard_strategy=TensorShardStrategy) + shard_strategy=TensorShardStrategy()) _ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False, initial_scale=2**5, @@ -25,8 +25,7 @@ _ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False, backoff_factor=0.5, growth_interval=1000, hysteresis=2, - max_scale=2**32, - lr=1e-3) + max_scale=2**32) ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), zero=dict( diff --git a/tests/test_zero_data_parallel/test_init_context.py b/tests/test_zero_data_parallel/test_init_context.py index 34f510272..0fe0cd19c 100644 --- a/tests/test_zero_data_parallel/test_init_context.py +++ b/tests/test_zero_data_parallel/test_init_context.py @@ -7,26 +7,27 @@ import colossalai import pytest import torch import torch.multiprocessing as mp +from colossalai.testing import parameterize from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device +from colossalai.utils.memory_tracer.model_data_memtracer import \ + GLOBAL_MODEL_DATA_TRACER from colossalai.zero.init_ctx import ZeroInitContext from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) from tests.components_to_test.registry import non_distributed_component_funcs from common import CONFIG -from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER -from colossalai.testing import parameterize @parameterize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')]) -@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_model_test(init_device, shard_strategy): +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def run_model_test(init_device, shard_strategy_class): for get_components_func in non_distributed_component_funcs: model_builder, _, _, _, _ = get_components_func() model_numel_tensor = torch.zeros(1, dtype=torch.int) with ZeroInitContext(convert_fp16=True, target_device=init_device, - shard_strategy=shard_strategy(), + shard_strategy=shard_strategy_class(), shard_param=True, model_numel_tensor=model_numel_tensor): model = model_builder(checkpoint=True) diff --git a/tests/test_zero_data_parallel/test_shard_param.py b/tests/test_zero_data_parallel/test_shard_param.py index c38f24b50..16cac2c15 100644 --- a/tests/test_zero_data_parallel/test_shard_param.py +++ b/tests/test_zero_data_parallel/test_shard_param.py @@ -9,22 +9,22 @@ import pytest import torch import torch.multiprocessing as mp from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.testing import parameterize from colossalai.utils import free_port from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) from colossalai.zero.sharded_param import ShardedParam, ShardedTensor from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_zero_data_parallel.common import CONFIG, allclose -from colossalai.testing import parameterize -@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_shard_tensor_with_strategy(shard_strategy, world_size): +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def run_shard_tensor_with_strategy(shard_strategy_class, world_size): t = ShardedTensor(tensor=torch.randn(world_size * 2, 3)) assert list(t.origin_shape) == [world_size * 2, 3] assert list(t.shape) == [world_size * 2, 3] - shard_strategy = shard_strategy(process_group=None) + shard_strategy = shard_strategy_class() # test shard strategy shard_strategy.shard([t]) diff --git a/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py b/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py index 738382d8c..eef37a734 100644 --- a/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py +++ b/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py @@ -11,6 +11,8 @@ import torch.multiprocessing as mp from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.utils import free_port +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import TensorShardStrategy from torchvision.models import resnet50 @@ -19,7 +21,7 @@ def run_dist(rank, world_size, port): # as this model has sync batch normalization # need to configure cudnn deterministic so that # randomness of convolution layers will be disabled - zero_config = dict(optimizer_config=dict(optimizer_class=torch.optim.Adam, lr=1e-3)) + zero_config = dict(model_config=dict(shard_strategy=TensorShardStrategy())) colossalai.launch(config=dict(zero=zero_config, cudnn_determinstic=True, cudnn_benchmark=False), rank=rank, world_size=world_size, @@ -27,7 +29,11 @@ def run_dist(rank, world_size, port): port=port, backend='nccl') - model = resnet50() + with ZeroInitContext(convert_fp16=True, + target_device=torch.cuda.current_device(), + shard_strategy=gpc.config.zero.model_config.shard_strategy, + shard_param=True): + model = resnet50() optimizer = torch.optim.Adam(model.parameters(), lr=0.001) criterion = torch.nn.CrossEntropyLoss() @@ -64,10 +70,6 @@ def run_dist(rank, world_size, port): 'expected the output from different ranks to be the same, but got different values' -# FIXME: enable this test in next PR - - -@pytest.mark.skip @pytest.mark.dist def test_sharded_optim_with_sync_bn(): """ diff --git a/tests/test_zero_data_parallel/test_zero_engine.py b/tests/test_zero_data_parallel/test_zero_engine.py index 55eb5b9f3..56ad85203 100644 --- a/tests/test_zero_data_parallel/test_zero_engine.py +++ b/tests/test_zero_data_parallel/test_zero_engine.py @@ -8,7 +8,6 @@ import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp -from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.utils import free_port from colossalai.zero.init_ctx import ZeroInitContext @@ -17,8 +16,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan from tests.components_to_test.registry import non_distributed_component_funcs from torch.nn.parallel import DistributedDataParallel as DDP -from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, - check_sharded_params_padding) +from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_params_padding) def run_dist(rank, world_size, port, parallel_config): @@ -35,18 +33,19 @@ def run_dist(rank, world_size, port, parallel_config): model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() with ZeroInitContext(convert_fp16=hasattr(gpc.config, 'fp16'), target_device=torch.cuda.current_device(), - shard_strategy=gpc.config.zero.model_config.shared_strategy( - gpc.get_group(ParallelMode.DATA)), + shard_strategy=gpc.config.zero.model_config.shard_strategy, shard_param=True): colo_model = model_builder(checkpoint=True) - torch_model = model_builder(checkpoint=True).half() - col_model_deepcopy(colo_model, torch_model) - torch_model = torch_model.cuda().float() + colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3) engine, train_dataloader, _, _ = colossalai.initialize(colo_model, - optimizer=optimizer_class, + optimizer=colo_optimizer, criterion=criterion, train_dataloader=train_dataloader) + torch_model = model_builder(checkpoint=True).half() + col_model_deepcopy(engine.model, torch_model) + torch_model = torch_model.cuda().float() + engine.train() torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3) @@ -102,7 +101,6 @@ def test_mp_engine(world_size): mp.spawn(run_func, nprocs=world_size) -@pytest.mark.skip @pytest.mark.dist @pytest.mark.parametrize("world_size", [1, 2]) def test_zero_engine(world_size):