[zero] improve adaptability for not-shard parameters (#708)

* adapt post grad hooks for not-shard parameters
* adapt optimizer for not-shard parameters
* offload gradients for not-replicated parameters
pull/712/head
HELSON 2022-04-11 13:38:51 +08:00 committed by GitHub
parent ab8c6b4a0e
commit a9b8300d54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 114 additions and 111 deletions

View File

@ -8,7 +8,7 @@ from .experts import FFNExperts, TPExperts
class ForceFP32Parameter(torch.nn.Parameter):
def half(self, memory_format=None):
return self
return self.data
class NormalNoiseGenerator:

View File

@ -142,6 +142,7 @@ class CPUAdam(torch.optim.Optimizer):
beta1, beta2 = group['betas']
if target_device.type == 'cpu':
assert p.data.numel() == p.grad.data.numel(), "parameter and gradient should have the same size"
assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
@ -151,8 +152,8 @@ class CPUAdam(torch.optim.Optimizer):
assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
bias_correction1 = 1 - beta1**state['step']
bias_correction2 = 1 - beta2**state['step']
# adam on cuda
self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],

View File

@ -213,7 +213,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
for param in self.param_list:
assert hasattr(param, 'colo_attr')
if not param.colo_attr.param_is_sharded and param.is_replicated:
if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated:
dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
param.colo_attr.remove_torch_payload()
@ -239,9 +239,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
self.model_numel_tensor += param.numel()
# mark whether the param is replicated
param.is_replicated = self.is_replicated
# convert parameters to half
param_half = half_fn(param)
param.data = param_half
@ -261,6 +258,13 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
param.data = param.colo_attr.sharded_data_tensor.payload # set param.data to payload
# mark whether the param is replicated
param.colo_attr.is_replicated = self.is_replicated
# mark whether the param should keep not sharded
# if True, the param is used as Zero stage 2
param.colo_attr.keep_not_shard = not self.shard_param
self.param_list.append(param)
# We must cast buffers

View File

@ -123,7 +123,7 @@ class ShardedModelV2(nn.Module):
ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group)
]
register_ophooks_recursively(self.module, self._ophook_list)
self.param_hook_mgr = BaseParamHookMgr(self.sharded_params)
self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
self.fp32_reduce_scatter = fp32_reduce_scatter
@ -177,8 +177,8 @@ class ShardedModelV2(nn.Module):
self.logger.error(f'dump memort tracer collected infomation to a {filename}', ranks=[0])
if gpc.get_global_rank() == 0:
with open(filename, 'w+') as f:
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device())/1e9} GB\n')
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device())/1e9} GB\n')
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
f.write('CUDA model data (GB)\n')
f.write(str(self._memstats_collector.model_data_list('cuda', 'GB')))
f.write('\n')
@ -254,10 +254,6 @@ class ShardedModelV2(nn.Module):
torch.cuda.current_stream().synchronize()
self.reducer.free()
# all reduce gradients for unsharded parameters
reduce_list = [p for p in self.unshard_params if p.is_replicated]
bucket_allreduce(reduce_list, self.process_group)
# 3. shard tensors not dealed in the zero hook
tensor_list = []
for p in self.sharded_params:
@ -279,15 +275,6 @@ class ShardedModelV2(nn.Module):
if not self._require_backward_grad_sync:
continue
# move unsharded param grad to saved_grad
if not p.colo_attr.param_is_sharded:
if p.colo_attr.offload_grad:
colo_model_data_move_to_cpu(p.grad)
if p.colo_attr.saved_grad.is_null():
p.colo_attr.saved_grad.reset_payload(p.grad.data)
else:
p.colo_attr.saved_grad.payload.add_(p.grad.data)
p.grad = None
@torch.no_grad()
@ -316,6 +303,18 @@ class ShardedModelV2(nn.Module):
assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
if not self._require_backward_grad_sync:
return
if param.colo_attr.is_replicated:
self._reduce_scatter_handler(param, grad)
else:
self._save_grad(param, grad)
# used to cheat Pytorch, since we can't return None
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
return empty_grad
def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None:
self.comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.comm_stream):
new_grad = grad.clone()
@ -334,9 +333,6 @@ class ShardedModelV2(nn.Module):
self._reduce_scatter_callback(param, new_grad)
orig_grad_data.record_stream(self.comm_stream)
torch.cuda.current_stream().wait_stream(self.comm_stream)
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
return empty_grad
def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
assert isinstance(reduced_grad,
@ -345,21 +341,35 @@ class ShardedModelV2(nn.Module):
if self.gradient_postdivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
reduced_grad.data.div_(self.gradient_postdivide_factor)
# FIXME(ver217): remove the below line when impl eviction policy
self._save_grad(param, reduced_grad)
# FIXME(ver217): refactor the below line when impl eviction policy
def _save_grad(self, param: Parameter, grad: torch.Tensor):
# move gradient to cpu
if param.colo_attr.offload_grad:
colo_model_data_move_to_cpu(reduced_grad)
colo_model_data_move_to_cpu(grad)
if self.reuse_fp16_shard:
# make parameters point to gradient
assert param.colo_attr.saved_grad.is_null(
), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad)
param.colo_attr.sharded_data_tensor.is_sharded = True
param.colo_attr.saved_grad.reset_payload(param.colo_attr.sharded_data_tensor.payload)
param.colo_attr.saved_grad.reset_payload(grad)
param.colo_attr.sharded_data_tensor.reset_payload(grad) # release the memory of param
if param.colo_attr.is_replicated:
param.colo_attr.sharded_data_tensor.is_sharded = True
else:
reduced_grad = cast_tensor_to_fp32(reduced_grad)
fp32_grad = cast_tensor_to_fp32(grad)
if param.colo_attr.saved_grad.is_null():
param.colo_attr.saved_grad.reset_payload(reduced_grad)
param.colo_attr.saved_grad.reset_payload(fp32_grad)
else:
param.colo_attr.saved_grad.payload.add_(reduced_grad.view_as(param.colo_attr.saved_grad.payload))
param.colo_attr.saved_grad.payload.add_(fp32_grad.view_as(param.colo_attr.saved_grad.payload))
# keep saved_grad in HOLD state
param.colo_attr.saved_grad.trans_state(TensorState.HOLD)
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':

View File

@ -68,9 +68,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
keep_unsharded (bool, optional): if True, optimizer won't shard unsharded parameters.
In Zero-2, set keep_unsharded to False.
In Zero-3, set keep_unsharded to True.
max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
dp_process_group (Optional[ProcessGroup], optional): data paralle process group. Defaults to None.
mp_process_group (Optional[ProcessGroup], optional): model paralle process group. Defaults to None.
@ -91,7 +88,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
growth_interval: float = 1000,
hysteresis: float = 2,
max_scale: int = 2**32,
keep_unsharded: bool = False,
dp_process_group: Optional[ProcessGroup] = None,
mp_process_group: Optional[ProcessGroup] = None) -> None:
assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
@ -125,10 +121,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
self._logger = get_dist_logger("ShardedOptimizerV2")
assert not (keep_unsharded and self._should_move_fp32_shards_h2d), \
"Keeping unsharded parameters can't be used with hybrid OS placement right now."
self.keep_unshard = keep_unsharded
# Store fp32 param shards
self._register_master_weight()
@ -139,6 +131,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
if self._use_memory_tracer:
GLOBAL_MODEL_DATA_TRACER.register_optimizer(self)
@property
def loss_scale(self):
return self.grad_scaler.scale.item()
def get_memory_usage(self) -> Tuple[int, int]:
""" Get the memory usage of the optimizer. Including master_params (param fp32),
momentum (``self.state[p]['exp_avg']``) variance (``self.state[p]['exp_avg_sq']``)
@ -166,6 +162,22 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
return cuda_use, cpu_use
def zero_grad(self, *args, **kwargs):
self._zero_grad()
def backward(self, loss: Tensor) -> None:
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
self.model.backward(loss)
def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
self.model.backward_by_grad(tensor, grad)
def clip_grad_norm(self, model: nn.Module, max_norm: float):
if self.optim_state == OptimState.SCALED:
self._unscale_grads()
return super().clip_grad_norm(model, max_norm)
def step(self, *args, **kwargs):
self._prepare_grads()
self._maybe_move_fp32_shards()
@ -193,26 +205,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self._logger.debug(
f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory, {self.get_memory_usage()[1] / 1e6} MB CUDA Memory!",
ranks=[0])
self._copy_master_param_to_param_fp16()
self._copy_master_model_to_model_fp16()
return ret
def backward(self, loss: Tensor) -> None:
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
self.model.backward(loss)
def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
self.model.backward_by_grad(tensor, grad)
def clip_grad_norm(self, model: nn.Module, max_norm: float):
if self.optim_state == OptimState.SCALED:
self._unscale_grads()
return super().clip_grad_norm(model, max_norm)
@property
def loss_scale(self):
return self.grad_scaler.scale.item()
def _check_overflow(self):
# clear previous overflow record
self._found_overflow.fill_(0.0)
@ -240,9 +235,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
p.grad.data.div_(self.loss_scale)
self.optim_state = OptimState.UNSCALED
def zero_grad(self, *args, **kwargs):
self._zero_grad()
def _zero_grad(self, recover_data: bool = False):
"""zero grad and maybe recover fp16 params
When `reuse_fp16_shard` is enabled,
@ -262,13 +254,11 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# p.colo_attr.sharded_data_tensor stores grad now
# we have to recover fp16 param
reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr()
p.colo_attr.saved_grad.set_null()
if recover_data and reuse_fp16_shard:
# We should write like this to trigger ForceFP32Paramter's half method
p.data = self.master_params[p].payload
p.colo_attr.sharded_data_tensor.reset_payload(
colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
p.colo_attr.remove_torch_payload()
self._copy_master_param_to_param_fp16(p)
else:
# release saved gradient
p.colo_attr.saved_grad.set_null()
def sync_grad(self):
pass
@ -278,14 +268,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
for group in self.optim.param_groups:
for p in group['params']:
assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
if not is_param_sharded and not self.keep_unshard:
# Please use keep_unsharded to control whether shard unsharded paramters
# As we only store param shard, we shard it here
shard_flag = not p.colo_attr.sharded_data_tensor.is_sharded and p.colo_attr.is_replicated
if shard_flag:
# we always shard replicated paramters
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
self.master_params[p] = StatefulTensor(
cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload.to(self.device)))
if not is_param_sharded and not self.keep_unshard:
if shard_flag:
# In this branch, there's no need to shard param
# So we gather here
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
@ -328,31 +317,27 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# Now p.data is sharded
# So optimizer states are sharded naturally
def _copy_master_param_to_param_fp16(self):
def _copy_master_model_to_model_fp16(self):
# Copy master param data (fp32) to payload of colo_attr (fp16)
# TODO() improve efficiency by gathering tensors into a chunk and transfering
# a chunk.
for group in self.optim.param_groups:
for p in group['params']:
is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
if not is_param_sharded and not self.keep_unshard:
# We use ZeRO-2 here
# The `p.colo_attr.sharded_data_tensor` saves full fp16 param
# But we only have updated fp32 param shard here
# So we first shard full fp16 param and copy fp32 param shard to it
# Then we will gather them
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
# We have to use `copy_payload` instead of `reset_payload`
# Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
self._copy_master_param_to_param_fp16(p)
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.colo_attr.sharded_data_tensor.reset_payload(
colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
p.colo_attr.remove_torch_payload()
def _copy_master_param_to_param_fp16(self, p):
# flush gradient
p.colo_attr.saved_grad.set_null()
if not is_param_sharded and not self.keep_unshard:
# We gather full fp16 param here
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.data = self.master_params[p].payload
p.colo_attr.sharded_data_tensor.reset_payload(
colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
p.colo_attr.remove_torch_payload()
self.master_params[p].trans_state(TensorState.HOLD)
p.colo_attr.saved_grad.set_null()
if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
# We gather full fp16 param here
p.colo_attr.sharded_data_tensor.is_sharded = True # since only gradient is sharded, we should set to True
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
self.master_params[p].trans_state(TensorState.HOLD)

View File

@ -71,9 +71,9 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
# the parameters in moe experts is not replicated
if 'experts' in name:
assert not param.is_replicated
assert not param.colo_attr.is_replicated
else:
assert param.is_replicated
assert param.colo_attr.is_replicated
if param.colo_attr.param_is_sharded:
assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \

View File

@ -36,7 +36,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
# check whether parameters are identical in ddp
for name, p in zero_model.named_parameters():
if not p.colo_attr.param_is_sharded and p.is_replicated:
if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload)
model = MoeModel().half()

View File

@ -48,8 +48,13 @@ def _run_step(model, optimizer, data, label, criterion, grad_handler):
@parameterize("cpu_offload", [True])
@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug
@parameterize("reuse_fp16_shard", [True, False])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio=0.0):
def _run_test_sharded_optim_v2(cpu_offload,
shard_strategy_class,
use_cpuadam,
reuse_fp16_shard,
gpu_margin_mem_ratio=0.0):
shard_strategy = shard_strategy_class()
if use_cpuadam and cpu_offload is False:
return
@ -63,17 +68,15 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
shard_param=True):
zero_model = MoeModel()
zero_model = ShardedModelV2(
zero_model,
shard_strategy,
offload_config=dict(device='cpu') if cpu_offload else None,
use_memory_tracer=gpu_margin_mem_ratio > 0.0,
reuse_fp16_shard=use_cpuadam,
)
zero_model = ShardedModelV2(zero_model,
shard_strategy,
offload_config=dict(device='cpu') if cpu_offload else None,
use_memory_tracer=gpu_margin_mem_ratio > 0.0,
reuse_fp16_shard=reuse_fp16_shard)
# check whether parameters are identical in ddp
for name, p in zero_model.named_parameters():
if not p.colo_attr.param_is_sharded and p.is_replicated:
if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload.to(get_current_device()))
model = MoeModel().half()
@ -88,8 +91,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
sharded_optim,
cpu_offload=cpu_offload,
initial_scale=2**5,
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
keep_unsharded=True)
gpu_margin_mem_ratio=gpu_margin_mem_ratio)
amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)

View File

@ -93,7 +93,7 @@ def check_grads_padding(model, zero_model, loose=False):
rank = dist.get_rank()
for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
# zero_grad = zero_p.grad.clone().to(p.device)
if zero_p.colo_attr.param_is_sharded:
if zero_p.colo_attr.is_replicated:
zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
if rank >= len(chunks):
@ -102,8 +102,9 @@ def check_grads_padding(model, zero_model, loose=False):
if zero_grad.size(0) > grad.size(0):
zero_grad = zero_grad[:grad.size(0)]
else:
grad = p.grad
zero_grad = zero_p.colo_attr.saved_grad.payload
grad = p.grad.to(zero_grad.dtype)
assert grad.dtype == zero_grad.dtype
assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}'
@ -134,7 +135,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
if zero_p.size(0) > p.size(0):
zero_p = zero_p[:p.size(0)]
else:
zero_p = zero_p.colo_attr.sharded_data_tensor.payload
zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device)
assert p.dtype == zero_p.dtype
assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'