[zero] adapt zero for unsharded paramters (Optimizer part) (#601)

pull/632/head^2
HELSON 2022-04-01 20:10:47 +08:00 committed by GitHub
parent 229382c844
commit 055fbf5be6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 208 additions and 44 deletions

View File

@ -6,7 +6,10 @@ import torch.distributed as dist
from colossalai.communication.collective import scatter_object_list
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
from .common import is_using_pp

View File

@ -11,6 +11,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_param import ShardedParamV2
from torch.distributed import ProcessGroup
from contextlib import AbstractContextManager
def _substitute_init_recursively(cls, func):
@ -88,6 +89,7 @@ class ZeroContextConfig(object):
"""The configuration used to control zero context initialization.
Args:
target_device (torch.device): The device where param data are after exiting the context.
replicated (bool, optional): Whether the param is replicated across data parallel group.
Some parameters are not replicated, e.g. parameters in MOE experts.
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
@ -99,8 +101,13 @@ class ZeroContextConfig(object):
See torchvision resnet18. Defaults to False.
"""
def __init__(self, replicated: bool = True, shard_param: bool = False, rm_torch_payload_on_the_fly: bool = False):
def __init__(self,
target_device: torch.device,
replicated: bool = True,
shard_param: bool = False,
rm_torch_payload_on_the_fly: bool = False):
super().__init__()
self.target_device = target_device
self.is_replicated: bool = replicated
self.shard_param: bool = shard_param
self.rm_torch_payload_on_the_fly: bool = rm_torch_payload_on_the_fly
@ -114,7 +121,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
3. Shard the param and grad according to flags.
Args:
target_device (torch.device): The device where param data after exiting the context.
target_device (torch.device): The device where param data are after exiting the context.
shard_strategy (BaseShardStrategy): Shard strategy instance.
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
rm_torch_payload_on_the_fly (bool, optional): If set to `True`, remove tensor payload on `param.data` after module init finished.
@ -136,17 +143,22 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
dp_process_group: Optional[ProcessGroup] = None):
super().__init__()
self.target_device = target_device
self.shard_strategy = shard_strategy
self.initialized_param_list = []
self.model_numel_tensor = model_numel_tensor
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
self.config = ZeroContextConfig(replicated=True,
self.config = ZeroContextConfig(target_device=target_device,
replicated=True,
shard_param=shard_param,
rm_torch_payload_on_the_fly=rm_torch_payload_on_the_fly)
ZeroContextMgr().current_context = self
@property
def target_device(self):
return self.config.target_device
@property
def is_replicated(self):
return self.config.is_replicated
@ -235,8 +247,9 @@ class ZeroContextMgr(metaclass=SingletonMeta):
self.current_context.config = old_config
def no_shard_zero_context(is_replicated: bool = True):
return ZeroContextMgr().hijack_context_config(replicated=is_replicated,
def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
replicated=is_replicated,
shard_param=False,
rm_torch_payload_on_the_fly=False)

View File

@ -12,13 +12,12 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.zero.shard_utils.tensor_utils import (colo_model_tensor_clone, colo_tensor_mem_usage)
from colossalai.zero.shard_utils.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
@ -69,6 +68,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
keep_unsharded (bool, optional): if True, optimizer won't shard unsharded parameters.
In Zero-2, set keep_unsharded to False.
In Zero-3, set keep_unsharded to True.
max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
dp_process_group (Optional[ProcessGroup], optional): data paralle process group. Defaults to None.
mp_process_group (Optional[ProcessGroup], optional): model paralle process group. Defaults to None.
@ -89,6 +91,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
growth_interval: float = 1000,
hysteresis: float = 2,
max_scale: int = 2**32,
keep_unsharded: bool = False,
dp_process_group: Optional[ProcessGroup] = None,
mp_process_group: Optional[ProcessGroup] = None) -> None:
assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
@ -122,24 +125,12 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
self._logger = get_dist_logger("ShardedOptimizerV2")
# Store fp32 param shards
self.master_params: Dict[Parameter, StatefulTensor] = {}
assert not (keep_unsharded and self._should_move_fp32_shards_h2d), \
"Keeping unsharded parameters can't be used with hybrid OS placement right now."
self.keep_unshard = keep_unsharded
for group in self.optim.param_groups:
for p in group['params']:
assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
if not is_param_sharded:
# TODO (ver217): we may not use shard / gather here
# Param is no sharded, which means we use ZeRO-2 here
# As we only store param shard, we shard it here
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
self.master_params[p] = StatefulTensor(
cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
if not is_param_sharded:
# In this branch, there's no need to shard param
# So we gather here
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
# Store fp32 param shards
self._register_master_weight()
self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
ranks=[0])
@ -283,6 +274,23 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
def sync_grad(self):
pass
def _register_master_weight(self):
self.master_params: Dict[Parameter, StatefulTensor] = {}
for group in self.optim.param_groups:
for p in group['params']:
assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
if not is_param_sharded and not self.keep_unshard:
# Please use keep_unsharded to control whether shard unsharded paramters
# As we only store param shard, we shard it here
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
self.master_params[p] = StatefulTensor(
cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
if not is_param_sharded and not self.keep_unshard:
# In this branch, there's no need to shard param
# So we gather here
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
def _maybe_move_fp32_shards(self):
if self._should_move_fp32_shards_h2d:
self._should_move_fp32_shards_h2d = False
@ -328,7 +336,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
for group in self.optim.param_groups:
for p in group['params']:
is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
if not is_param_sharded:
if not is_param_sharded and not self.keep_unshard:
# We use ZeRO-2 here
# The `p.colo_attr.sharded_data_tensor` saves full fp16 param
# But we only have updated fp32 param shard here
@ -342,7 +350,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
p.colo_attr.sharded_data_tensor.reset_payload(
colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
if not is_param_sharded:
if not is_param_sharded and not self.keep_unshard:
# We gather full fp16 param here
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
p.data = p.colo_attr.sharded_data_tensor.payload

View File

@ -42,4 +42,5 @@ def get_training_components():
testloader = DummyDataLoader()
criterion = torch.nn.CrossEntropyLoss()
return model_builder, trainloader, testloader, torch.optim.Adam, criterion
from colossalai.nn.optimizer import HybridAdam
return model_builder, trainloader, testloader, HybridAdam, criterion

View File

@ -76,8 +76,11 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
else:
assert param.is_replicated
assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
if param.colo_attr.param_is_sharded:
assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
else:
assert param.colo_attr.sharded_data_tensor.payload.device.type == 'cuda'
def _run_dist(rank, world_size, port):

View File

@ -67,7 +67,7 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("world_size", [2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_moe_zero_model(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())

View File

@ -0,0 +1,134 @@
from functools import partial
import colossalai
from colossalai.utils.cuda import get_current_device
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.amp import convert_to_apex_amp
from colossalai.nn.optimizer import CPUAdam
from colossalai.testing import parameterize, rerun_on_exception
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from colossalai.utils import get_current_device
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.engine.gradient_handler import MoeGradientHandler
from colossalai.context import MOE_CONTEXT
from colossalai.testing import assert_equal_in_group
from tests.test_zero_data_parallel.common import CONFIG, check_sharded_model_params
from tests.test_moe.test_moe_zero_init import MoeModel
def _run_step(model, optimizer, data, label, criterion, grad_handler):
model.train()
optimizer.zero_grad()
if criterion:
y = model(data)
loss = criterion(y, label)
else:
loss = model(data, label)
loss = loss.float()
if isinstance(model, ShardedModelV2):
optimizer.backward(loss)
else:
loss.backward()
if grad_handler is not None:
grad_handler.handle_gradient()
optimizer.step()
@parameterize("cpu_offload", [True, False])
@parameterize("use_cpuadam", [True, False])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio=0.0):
MOE_CONTEXT.reset_loss()
shard_strategy = shard_strategy_class()
if use_cpuadam and cpu_offload is False:
return
get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
_, train_dataloader, _, optimizer_class, criterion = get_components_func()
with ZeroInitContext(
target_device=torch.device('cpu') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
shard_strategy=shard_strategy,
shard_param=True,
rm_torch_payload_on_the_fly=False):
zero_model = MoeModel()
zero_model = ShardedModelV2(
zero_model,
shard_strategy,
offload_config=dict(device='cpu') if cpu_offload else None,
use_memory_tracer=gpu_margin_mem_ratio > 0.0,
reuse_fp16_shard=use_cpuadam,
)
# check whether parameters are identical in ddp
for name, p in zero_model.named_parameters():
if not p.colo_attr.param_is_sharded and p.is_replicated:
assert_equal_in_group(p.data.to(get_current_device()))
model = MoeModel().half()
col_model_deepcopy(zero_model, model)
model = model.cuda().float()
if use_cpuadam:
optimizer_class = CPUAdam
optim = optimizer_class(model.parameters(), lr=1e-3)
sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
sharded_optim = ShardedOptimizerV2(zero_model,
sharded_optim,
cpu_offload=cpu_offload,
initial_scale=2**5,
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
keep_unsharded=True)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
apex_grad_handler = MoeGradientHandler(model)
# Since MOE is not compatible with apex_amp now, we need to convert gate weight to fp32
for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()):
if 'gate' in n:
p.data = p.float()
p.data.copy_(zp.data)
for i, (data, label) in enumerate(train_dataloader):
if i > 5:
break
data, label = data.cuda(), label.cuda()
_run_step(apex_model, apex_optimizer, data, label, criterion, apex_grad_handler)
_run_step(zero_model, sharded_optim, data, label, criterion, None)
check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)
for param in model.parameters():
assert not has_inf_or_nan(param)
def _run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
MOE_CONTEXT.setup(seed=42)
_run_test_sharded_optim_v2()
# use_cpuadam = True can be used with cpu_offload = False
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_moe_zero_optim(world_size):
run_func = partial(_run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_moe_zero_optim(world_size=2)

View File

@ -124,16 +124,18 @@ def check_params_padding(model, zero_model, loose=False):
def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False):
rank = dist.get_rank()
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
if reuse_fp16_shard:
zero_p = zero_p.data.to(p.device).float()
else:
zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
chunks = torch.flatten(p).chunk(dist.get_world_size())
if rank >= len(chunks):
continue
p = chunks[rank].float()
if zero_p.size(0) > p.size(0):
zero_p = zero_p[:p.size(0)]
for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
if zero_p.colo_attr.param_is_sharded:
if reuse_fp16_shard:
zero_p = zero_p.data.to(p.device).float()
else:
zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
chunks = torch.flatten(p).chunk(dist.get_world_size())
if rank >= len(chunks):
continue
p = chunks[rank].float()
if zero_p.size(0) > p.size(0):
zero_p = zero_p[:p.size(0)]
assert p.dtype == zero_p.dtype
assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'