[zero] Update initialize for ZeRO (#458)

* polish code

* shard strategy receives the process group (pg) in shard() / gather()

* update zero engine

* polish code
ver217 committed 3 years ago via GitHub
parent 642846d6f9
commit a241f61b34

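The central API change in this commit is that shard strategies no longer carry a process group; shard() and gather() now receive it per call. A minimal sketch of the new call style, assuming a distributed run has already been launched (e.g. via colossalai.launch) and a CUDA device is available:

import torch
import torch.distributed as dist
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor

strategy = TensorShardStrategy()      # no process_group argument at construction any more
dp_group = None                       # None falls back to the default process group

t = ShardedTensor(tensor=torch.randn(dist.get_world_size(dp_group) * 2, 3))
strategy.shard([t], dp_group)         # each rank keeps only its own shard of the payload
strategy.gather([t], dp_group)        # all-gather restores the full payload
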
@ -1,6 +1,7 @@
from typing import Optional
import torch
import torch.distributed as dist
from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
@ -17,9 +18,13 @@ class ZeroHook(BaseOpHook):
A hook to process sharded param for ZeRO method.
"""
def __init__(self, shard_strategy: BaseShardStrategy, memstarts_collector: Optional[MemStatsCollector]):
def __init__(self,
shard_strategy: BaseShardStrategy,
memstarts_collector: Optional[MemStatsCollector],
process_group: Optional[dist.ProcessGroup] = None):
super().__init__()
self.shard_strategy = shard_strategy
self.process_group = process_group
# NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
self.computing_device = torch.device(f'cuda:{get_current_device()}')
@ -30,7 +35,7 @@ class ZeroHook(BaseOpHook):
for param in module.parameters():
assert hasattr(param, 'col_attr')
tensor_list.append(param.col_attr.data)
self.shard_strategy.gather(tensor_list)
self.shard_strategy.gather(tensor_list, self.process_group)
for param in module.parameters():
if param.col_attr.data.device != self.computing_device:
param.col_attr.data.to(self.computing_device)
@ -45,7 +50,7 @@ class ZeroHook(BaseOpHook):
for param in module.parameters():
assert hasattr(param, 'col_attr')
tensor_list.append(param.col_attr.data)
self.shard_strategy.shard(tensor_list)
self.shard_strategy.shard(tensor_list, self.process_group)
for param in module.parameters():
param.col_attr.remove_torch_payload()
@ -54,7 +59,7 @@ class ZeroHook(BaseOpHook):
for param in module.parameters():
assert hasattr(param, 'col_attr')
tensor_list.append(param.col_attr.data)
self.shard_strategy.gather(tensor_list)
self.shard_strategy.gather(tensor_list, self.process_group)
for param in module.parameters():
if param.col_attr.data.device != self.computing_device:
param.col_attr.data.to(self.computing_device)
@ -80,7 +85,7 @@ class ZeroHook(BaseOpHook):
for param in module.parameters():
assert hasattr(param, 'col_attr')
tensor_list.append(param.col_attr.data)
self.shard_strategy.shard(tensor_list)
self.shard_strategy.shard(tensor_list, self.process_group)
for param in module.parameters():
param.col_attr.remove_torch_payload()

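The hook above drives the gather/shard cycle around every op: before an op executes, the sharded fp16 payloads are gathered and moved to the computing device; afterwards they are sharded again and the full torch payload is released. A framework-free sketch of that lifecycle, with strategy, params, dp_group, and computing_device as placeholder names:

def pre_op(strategy, params, dp_group, computing_device):
    tensors = [p.col_attr.data for p in params]
    strategy.gather(tensors, dp_group)             # full payloads become available
    for p in params:
        if p.col_attr.data.device != computing_device:
            p.col_attr.data.to(computing_device)   # move payloads onto the computing device

def post_op(strategy, params, dp_group):
    tensors = [p.col_attr.data for p in params]
    strategy.shard(tensors, dp_group)              # keep only the local shard
    for p in params:
        p.col_attr.remove_torch_payload()          # free the full fp16 copy
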
@ -278,7 +278,10 @@ def initialize(model: nn.Module,
cfg_ = {}
optimizer_config = zero_cfg.get('optimizer_config', None)
model_config = zero_cfg.get('model_config', None)
model, optimizer = convert_to_zero_v2(model, model_config=model_config, optimizer_config=optimizer_config)
model, optimizer = convert_to_zero_v2(model,
optimizer,
model_config=model_config,
optimizer_config=optimizer_config)
logger.info("Initializing ZeRO model and optimizer finished!", ranks=[0])
# FIXME() throw a warning if using zero with MP

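With convert_to_zero_v2 now taking the optimizer instance, the user-facing flow is to build a plain torch optimizer and hand it to colossalai.initialize, which wraps it into a ShardedOptimizerV2 (mirroring the updated engine test further down in this diff). An illustrative sketch, assuming colossalai.launch has already been called with a ZeRO config; the model here is a stand-in:

import torch
import torch.nn as nn
import colossalai

model = nn.Linear(8, 8)                                     # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)   # plain torch optimizer, built by the user
criterion = nn.CrossEntropyLoss()
engine, _, _, _ = colossalai.initialize(model, optimizer=optimizer, criterion=criterion)
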
@ -1,5 +1,6 @@
from typing import Tuple
import torch
import torch.nn as nn
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.logging import get_dist_logger
@ -11,7 +12,8 @@ from .sharded_model import ShardedModel
from .sharded_optim import ShardedOptimizer
def convert_to_zero_v2(model: nn.Module, model_config, optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]:
def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]:
"""
A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
@ -34,7 +36,7 @@ def convert_to_zero_v2(model: nn.Module, model_config, optimizer_config) -> Tupl
model_config = dict()
zero_model = ShardedModelV2(model, **model_config)
zero_optimizer = ShardedOptimizerV2(zero_model, **optimizer_config)
zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config)
return zero_model, zero_optimizer

@ -1,26 +1,21 @@
from abc import ABC, abstractmethod
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
import torch.distributed as dist
from typing import List, Optional
import torch.distributed as dist
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
class BaseShardStrategy(ABC):
def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
def __init__(self) -> None:
"""Abstract shard strategy. Used to shard tensors across multiple GPUs.
Args:
process_group (Optional[dist.ProcessGroup], optional): the process group. Defaults to None.
"""
self.process_group = process_group
self.world_size = dist.get_world_size(self.process_group)
self.local_rank = dist.get_rank(self.process_group)
super().__init__()
@abstractmethod
def shard(self, tensor_list: List[ShardedTensor]):
def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
pass
@abstractmethod
def gather(self, tensor_list: List[ShardedTensor]):
def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
pass

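BaseShardStrategy is now a pure interface: construction takes no arguments, and subclasses receive the optional process group on every shard()/gather() call. An illustrative, do-nothing subclass showing the new signatures:

from typing import List, Optional
import torch.distributed as dist
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor

class NoopShardStrategy(BaseShardStrategy):
    # Hypothetical example only: a real strategy splits each payload across
    # dist.get_world_size(process_group) ranks in shard() and all-gathers it back in gather().

    def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        pass

    def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        pass
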
@ -1,18 +1,17 @@
from typing import List
from typing import List, Optional
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from torch._utils import _flatten_dense_tensors as flatten
from .tensor_shard_strategy import TensorShardStrategy
class BucketTensorShardStrategy(TensorShardStrategy):
def gather(self, tensor_list: List[ShardedTensor]):
def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
if len(tensor_list) == 0:
return
@ -21,15 +20,17 @@ class BucketTensorShardStrategy(TensorShardStrategy):
buffer_list: List[torch.Tensor] = []
tensor_numels = [t.payload.numel() for t in tensor_list]
buffer_size = sum(tensor_numels)
for i in range(self.world_size):
if i == self.local_rank:
world_size = dist.get_world_size(process_group)
rank = dist.get_rank(process_group)
for i in range(world_size):
if i == rank:
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
# Release payload here, to decrease peak memory usage
for t in tensor_list:
t.reset_payload(None)
else:
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group)
dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
# Move to target device before splitting buffer
# Ensure we utilize maximum PCIE bandwidth
buffer_list = [buffer.to(target_device) for buffer in buffer_list]

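BucketTensorShardStrategy batches the gather: every rank flattens its local shards into one contiguous buffer, so a single all_gather replaces one collective per tensor. A stand-alone sketch of that pattern in plain torch.distributed, assuming an initialized process group and CUDA devices:

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten

def bucket_all_gather(local_shards, process_group=None):
    world_size = dist.get_world_size(process_group)
    rank = dist.get_rank(process_group)
    buffer_size = sum(t.numel() for t in local_shards)
    local_buffer = flatten(local_shards).cuda()          # one flat buffer for this rank's shards
    buffer_list = [local_buffer if i == rank else
                   torch.zeros(buffer_size, dtype=local_buffer.dtype, device=local_buffer.device)
                   for i in range(world_size)]
    dist.all_gather(buffer_list, buffer_list[rank], group=process_group)   # one collective for all tensors
    return buffer_list
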
@ -2,49 +2,44 @@ from typing import List, Optional
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._zero3_utils import get_shard
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils import get_current_device
class TensorShardStrategy(BaseShardStrategy):
def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
super().__init__(process_group)
def shard(self, tensor_list: List[ShardedTensor]):
def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
for t in tensor_list:
self._shard_tensor(t)
self._shard_tensor(t, process_group)
def gather(self, tensor_list: List[ShardedTensor]):
def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
for t in tensor_list:
self._gather_tensor(t)
self._gather_tensor(t, process_group)
def _shard_tensor(self, t: ShardedTensor):
def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
if t.is_sharded:
return
sharded_payload, _ = get_shard(t.payload, self.local_rank, self.world_size)
sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
t.reset_payload(sharded_payload)
t.is_sharded = True
def _gather_tensor(self, t: ShardedTensor):
def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
if not t.is_sharded:
return
target_device = t.device
buffer_list = []
payload_numel = t.payload.numel()
for i in range(self.world_size):
if i == self.local_rank:
world_size = dist.get_world_size(process_group)
rank = dist.get_rank(process_group)
for i in range(world_size):
if i == rank:
buffer_list.append(t.payload.cuda(get_current_device()))
else:
buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device()))
torch.distributed.all_gather(buffer_list,
buffer_list[self.local_rank],
group=self.process_group,
async_op=False)
dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape)
t.reset_payload(gathered_payload)
t.to(target_device)

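TensorShardStrategy gathers one tensor at a time: each rank contributes its (possibly padded) shard, the buffers are concatenated, trimmed back to the original element count with narrow, and reshaped. A small self-contained sketch of that reconstruction step, assuming equally sized, padded shards:

import torch

def rebuild_from_shards(buffer_list, origin_numel, origin_shape):
    flat = torch.cat(buffer_list)                                          # concatenate per-rank buffers
    return torch.narrow(flat, 0, 0, origin_numel).reshape(origin_shape)    # drop padding, restore shape

# e.g. a (3, 3) tensor split across 2 ranks as two padded 5-element shards:
shards = [torch.arange(5, dtype=torch.float32), torch.arange(5, 10, dtype=torch.float32)]
restored = rebuild_from_shards(shards, origin_numel=9, origin_shape=(3, 3))
assert restored.shape == (3, 3)
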
@ -70,7 +70,8 @@ class ShardedModelV2(nn.Module):
self._iter_cnter = 0
# Register hooks
register_ophooks_recursively(self.module, [ZeroHook(self.shard_strategy, self._memstats_collector)])
register_ophooks_recursively(self.module,
[ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)])
self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
@ -145,7 +146,7 @@ class ShardedModelV2(nn.Module):
if self.shard_param:
for p in self.module.parameters():
if not p.col_attr.param_is_sharded:
self.shard_strategy.shard([p.col_attr.data])
self.shard_strategy.shard([p.col_attr.data], self.process_group)
for p in self.module.parameters():
p.col_attr.bwd_count = 0
if not p.requires_grad:
@ -229,13 +230,13 @@ class ShardedModelV2(nn.Module):
param.col_attr.fp16_grad = reduced_grad.data
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()])
self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()], self.process_group)
prev_params = {}
for p in self.module.parameters():
prev_params[p] = p.data
p.data = p.col_attr.data.payload
gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
self.shard_strategy.shard([p.col_attr.data for p in self.module.parameters()])
self.shard_strategy.shard([p.col_attr.data for p in self.module.parameters()], self.process_group)
for p in self.module.parameters():
p.data = prev_params[p]
return gathered_state_dict

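ShardedModelV2.state_dict() now forwards the data-parallel group as well: it gathers the full fp16 parameters, temporarily swaps them into p.data, builds an ordinary state dict, then re-shards. From the caller's side the usage stays unchanged; a hedged usage sketch, with zero_model as a placeholder for an already constructed ShardedModelV2:

import torch

full_state = zero_model.state_dict()       # gathers, snapshots, then re-shards internally
torch.save(full_state, 'checkpoint.pt')    # an ordinary checkpoint dict on every rank
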
@ -7,6 +7,7 @@ import torch.nn as nn
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
@ -101,6 +102,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
hysteresis=hysteresis,
max_scale=max_scale)
self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
self._logger = get_dist_logger()
# Store fp32 param shards
self.master_params: Dict[Parameter, Tensor] = {}
@ -113,12 +115,12 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# TODO (ver217): we may not use shard / gather here
# Param is not sharded, which means we use ZeRO-2 here
# As we only store param shard, we shard it here
self.shard_strategy.shard([p.col_attr.data])
self.shard_strategy.shard([p.col_attr.data], self.dp_process_group)
self.master_params[p] = cast_tensor_to_fp32(p.col_attr.data.payload).to(self.device)
if not is_param_sharded:
# In this branch, there's no need to shard param
# So we gather here
self.shard_strategy.gather([p.col_attr.data])
self.shard_strategy.gather([p.col_attr.data], self.dp_process_group)
def step(self, *args, **kwargs):
# unscale grads if scaled
@ -155,7 +157,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# But we only have updated fp32 param shard here
# So we first shard full fp16 param and copy fp32 param shard to it
# Then we will gather them
self.shard_strategy.shard([p.col_attr.data])
self.shard_strategy.shard([p.col_attr.data], self.dp_process_group)
# We have to use `copy_payload` instead of `reset_payload`
# Since p.data is fp32 and p.col_attr.data is fp16
@ -164,7 +166,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
if not is_param_sharded:
# We gather full fp16 param here
self.shard_strategy.gather([p.col_attr.data])
self.shard_strategy.gather([p.col_attr.data], self.dp_process_group)
p.data = p.col_attr.data.payload
return ret

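The optimizer keeps fp32 master shards and, after step(), copies the updated fp32 shard back into the (re-sharded) fp16 payload via copy_payload. A self-contained, framework-free analogue of that fp16/fp32 round trip:

import torch

fp16_param = torch.randn(4, 4, dtype=torch.float16)    # working copy used by forward/backward
fp32_master = fp16_param.float().clone()                # master copy owned by the optimizer

# Stand-in for an optimizer update on the fp32 master copy.
fp32_master -= 1e-3 * torch.randn_like(fp32_master)

fp16_param.copy_(fp32_master.half())                    # copy the updated values back into fp16
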
@ -16,7 +16,7 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
offload_config=None,
gradient_predivide_factor=1.0,
use_memory_tracer=False,
shard_strategy=TensorShardStrategy)
shard_strategy=TensorShardStrategy())
_ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
initial_scale=2**5,
@ -25,8 +25,7 @@ _ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
backoff_factor=0.5,
growth_interval=1000,
hysteresis=2,
max_scale=2**32,
lr=1e-3)
max_scale=2**32)
ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,),
zero=dict(

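The default ZeRO config now stores a strategy instance (TensorShardStrategy()) and the optimizer config no longer carries an lr, since the user constructs the optimizer directly. An example config in the new shape, with values mirroring the defaults above:

from colossalai.zero.shard_utils import TensorShardStrategy

zero_config = dict(
    model_config=dict(shard_strategy=TensorShardStrategy()),
    optimizer_config=dict(cpu_offload=False, initial_scale=2**5),
)
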
@ -7,26 +7,27 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from tests.components_to_test.registry import non_distributed_component_funcs
from common import CONFIG
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.testing import parameterize
@parameterize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
def run_model_test(init_device, shard_strategy):
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_model_test(init_device, shard_strategy_class):
for get_components_func in non_distributed_component_funcs:
model_builder, _, _, _, _ = get_components_func()
model_numel_tensor = torch.zeros(1, dtype=torch.int)
with ZeroInitContext(convert_fp16=True,
target_device=init_device,
shard_strategy=shard_strategy(),
shard_strategy=shard_strategy_class(),
shard_param=True,
model_numel_tensor=model_numel_tensor):
model = model_builder(checkpoint=True)

@ -9,22 +9,22 @@ import pytest
import torch
import torch.multiprocessing as mp
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.testing import parameterize
from colossalai.utils import free_port
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.test_zero_data_parallel.common import CONFIG, allclose
from colossalai.testing import parameterize
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
def run_shard_tensor_with_strategy(shard_strategy, world_size):
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def run_shard_tensor_with_strategy(shard_strategy_class, world_size):
t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
assert list(t.origin_shape) == [world_size * 2, 3]
assert list(t.shape) == [world_size * 2, 3]
shard_strategy = shard_strategy(process_group=None)
shard_strategy = shard_strategy_class()
# test shard strategy
shard_strategy.shard([t])

@ -11,6 +11,8 @@ import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from torchvision.models import resnet50
@ -19,7 +21,7 @@ def run_dist(rank, world_size, port):
# as this model has sync batch normalization
# need to configure cudnn deterministic so that
# randomness of convolution layers will be disabled
zero_config = dict(optimizer_config=dict(optimizer_class=torch.optim.Adam, lr=1e-3))
zero_config = dict(model_config=dict(shard_strategy=TensorShardStrategy()))
colossalai.launch(config=dict(zero=zero_config, cudnn_determinstic=True, cudnn_benchmark=False),
rank=rank,
world_size=world_size,
@ -27,7 +29,11 @@ def run_dist(rank, world_size, port):
port=port,
backend='nccl')
model = resnet50()
with ZeroInitContext(convert_fp16=True,
target_device=torch.cuda.current_device(),
shard_strategy=gpc.config.zero.model_config.shard_strategy,
shard_param=True):
model = resnet50()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
@ -64,10 +70,6 @@ def run_dist(rank, world_size, port):
'expected the output from different ranks to be the same, but got different values'
# FIXME: enable this test in next PR
@pytest.mark.skip
@pytest.mark.dist
def test_sharded_optim_with_sync_bn():
"""

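The sync-BN test now builds the model inside ZeroInitContext so parameters are sharded at construction time, with the strategy instance taken from the config. A condensed sketch of that pattern, assuming a launched distributed run and a CUDA device:

import torch
from torchvision.models import resnet50
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy

with ZeroInitContext(convert_fp16=True,
                     target_device=torch.cuda.current_device(),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True):
    model = resnet50()    # parameters are sharded as they are created
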
@ -8,7 +8,6 @@ import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
@ -17,8 +16,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params,
check_sharded_params_padding)
from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_params_padding)
def run_dist(rank, world_size, port, parallel_config):
@ -35,18 +33,19 @@ def run_dist(rank, world_size, port, parallel_config):
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
with ZeroInitContext(convert_fp16=hasattr(gpc.config, 'fp16'),
target_device=torch.cuda.current_device(),
shard_strategy=gpc.config.zero.model_config.shared_strategy(
gpc.get_group(ParallelMode.DATA)),
shard_strategy=gpc.config.zero.model_config.shard_strategy,
shard_param=True):
colo_model = model_builder(checkpoint=True)
torch_model = model_builder(checkpoint=True).half()
col_model_deepcopy(colo_model, torch_model)
torch_model = torch_model.cuda().float()
colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3)
engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
optimizer=optimizer_class,
optimizer=colo_optimizer,
criterion=criterion,
train_dataloader=train_dataloader)
torch_model = model_builder(checkpoint=True).half()
col_model_deepcopy(engine.model, torch_model)
torch_model = torch_model.cuda().float()
engine.train()
torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3)
@ -102,7 +101,6 @@ def test_mp_engine(world_size):
mp.spawn(run_func, nprocs=world_size)
@pytest.mark.skip
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
def test_zero_engine(world_size):
