mirror of https://github.com/hpcaitech/ColossalAI
[zero] improve adaptability for not-shard parameters (#708)
* adapt post grad hooks for not-shard parameters
* adapt optimizer for not-shard parameters
* offload gradients for not-replicated parameters
parent ab8c6b4a0e
commit a9b8300d54
@@ -8,7 +8,7 @@ from .experts import FFNExperts, TPExperts
 class ForceFP32Parameter(torch.nn.Parameter):
 
     def half(self, memory_format=None):
-        return self
+        return self.data
 
 
 class NormalNoiseGenerator:
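Note: after this change, calling half() on a ForceFP32Parameter hands back the underlying fp32 tensor rather than the Parameter subclass itself, so downstream code that clones or re-wraps the result gets a plain tensor that stays in fp32. A minimal self-contained sketch of that behaviour (illustration only, not ColossalAI's actual module):

    import torch

    class ForceFP32Parameter(torch.nn.Parameter):
        """A parameter that refuses to be narrowed to fp16."""

        def half(self, memory_format=None):
            # Return the raw fp32 data tensor rather than the Parameter wrapper.
            return self.data

    w = ForceFP32Parameter(torch.randn(4, 4))
    print(w.half().dtype)   # torch.float32 -- the cast is a no-op
    print(type(w.half()))   # plain torch.Tensor, not ForceFP32Parameter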
@@ -142,6 +142,7 @@ class CPUAdam(torch.optim.Optimizer):
                 beta1, beta2 = group['betas']
 
                 if target_device.type == 'cpu':
+                    assert p.data.numel() == p.grad.data.numel(), "parameter and gradient should have the same size"
                     assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
                     assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
                     self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'],
@@ -151,8 +152,8 @@ class CPUAdam(torch.optim.Optimizer):
                     assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda"
                     assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda"
 
-                    bias_correction1 = 1 - beta1 ** state['step']
-                    bias_correction2 = 1 - beta2 ** state['step']
+                    bias_correction1 = 1 - beta1**state['step']
+                    bias_correction2 = 1 - beta2**state['step']
 
                     # adam on cuda
                     self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
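The bias-correction lines above are only reformatted (whitespace around **), but they are the standard Adam correction terms. A rough standalone sketch of the update torch_adam_update performs for one tensor, using the same corrections (my own illustration, not CPUAdam's fused kernel):

    import torch

    def adam_update(p, grad, exp_avg, exp_avg_sq, lr, beta1, beta2, eps, step):
        # Biased first and second moment estimates.
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # Bias corrections, as in the hunk above.
        bias_correction1 = 1 - beta1**step
        bias_correction2 = 1 - beta2**step

        # Parameter update.
        denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
        p.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)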
@@ -213,7 +213,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
         for param in self.param_list:
             assert hasattr(param, 'colo_attr')
-            if not param.colo_attr.param_is_sharded and param.is_replicated:
+            if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated:
                 dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
                 param.colo_attr.remove_torch_payload()
@@ -239,9 +239,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
             self.model_numel_tensor += param.numel()
 
-            # mark whether the param is replicated
-            param.is_replicated = self.is_replicated
-
             # convert parameters to half
             param_half = half_fn(param)
             param.data = param_half
@@ -261,6 +258,13 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                 self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
             param.data = param.colo_attr.sharded_data_tensor.payload    # set param.data to payload
 
+            # mark whether the param is replicated
+            param.colo_attr.is_replicated = self.is_replicated
+
+            # mark whether the param should keep not sharded
+            # if True, the param is used as Zero stage 2
+            param.colo_attr.keep_not_shard = not self.shard_param
+
             self.param_list.append(param)
 
             # We must cast buffers
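The init context now records replication and sharding intent on the parameter's colo_attr instead of on the parameter itself. A small sketch of that bookkeeping with a hypothetical stand-in attribute (ParamMeta and param.meta are illustrative names, not ColossalAI's ShardedParamV2):

    import torch
    from dataclasses import dataclass

    @dataclass
    class ParamMeta:
        # hypothetical stand-in for colo_attr, for illustration only
        is_replicated: bool
        keep_not_shard: bool

    def tag_param(param: torch.nn.Parameter, is_replicated: bool, shard_param: bool) -> torch.nn.Parameter:
        # Mirror of the flags added above: is the parameter replicated across
        # data-parallel ranks, and should it stay unsharded (ZeRO-2 style)?
        param.meta = ParamMeta(is_replicated=is_replicated, keep_not_shard=not shard_param)
        return param

    # e.g. an MoE expert weight owned by a single rank and never sharded
    expert_w = tag_param(torch.nn.Parameter(torch.randn(4, 4)), is_replicated=False, shard_param=False)
    print(expert_w.meta)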
@@ -123,7 +123,7 @@ class ShardedModelV2(nn.Module):
             ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group)
         ]
         register_ophooks_recursively(self.module, self._ophook_list)
-        self.param_hook_mgr = BaseParamHookMgr(self.sharded_params)
+        self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
 
         self.fp32_reduce_scatter = fp32_reduce_scatter
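The hook manager is now built over every parameter of the module rather than only the sharded ones, so the post-backward hook also fires for unsharded parameters such as MoE experts. A hedged sketch of per-parameter gradient hooks using vanilla PyTorch (Tensor.register_hook; the real BaseParamHookMgr attaches to the gradient accumulation node, which this simplifies):

    import torch

    def register_grad_hooks(module: torch.nn.Module, post_backward_hook):
        # Register the same hook on every parameter that requires grad,
        # mirroring BaseParamHookMgr(list(module.parameters())).
        handles = []
        for p in module.parameters():
            if p.requires_grad:
                handles.append(p.register_hook(lambda grad, p=p: post_backward_hook(p, grad)))
        return handles

    # usage sketch
    model = torch.nn.Linear(8, 8)
    register_grad_hooks(model, lambda p, g: print(tuple(p.shape), g.norm().item()))
    model(torch.randn(2, 8)).sum().backward()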
@@ -177,8 +177,8 @@ class ShardedModelV2(nn.Module):
             self.logger.error(f'dump memort tracer collected infomation to a {filename}', ranks=[0])
             if gpc.get_global_rank() == 0:
                 with open(filename, 'w+') as f:
-                    f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device())/1e9} GB\n')
-                    f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device())/1e9} GB\n')
+                    f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
+                    f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
                     f.write('CUDA model data (GB)\n')
                     f.write(str(self._memstats_collector.model_data_list('cuda', 'GB')))
                     f.write('\n')
@@ -254,10 +254,6 @@ class ShardedModelV2(nn.Module):
             torch.cuda.current_stream().synchronize()
             self.reducer.free()
 
-        # all reduce gradients for unsharded parameters
-        reduce_list = [p for p in self.unshard_params if p.is_replicated]
-        bucket_allreduce(reduce_list, self.process_group)
-
         # 3. shard tensors not dealed in the zero hook
         tensor_list = []
         for p in self.sharded_params:
@@ -279,15 +275,6 @@ class ShardedModelV2(nn.Module):
             if not self._require_backward_grad_sync:
                 continue
 
-            # move unsharded param grad to saved_grad
-            if not p.colo_attr.param_is_sharded:
-                if p.colo_attr.offload_grad:
-                    colo_model_data_move_to_cpu(p.grad)
-                if p.colo_attr.saved_grad.is_null():
-                    p.colo_attr.saved_grad.reset_payload(p.grad.data)
-                else:
-                    p.colo_attr.saved_grad.payload.add_(p.grad.data)
-
             p.grad = None
 
     @torch.no_grad()
@@ -316,6 +303,18 @@ class ShardedModelV2(nn.Module):
         assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
         if not self._require_backward_grad_sync:
             return
+
+        if param.colo_attr.is_replicated:
+            self._reduce_scatter_handler(param, grad)
+        else:
+            self._save_grad(param, grad)
+
+        # used to cheat Pytorch, since we can't return None
+        empty_grad = torch.empty_like(grad)
+        free_storage(empty_grad)
+        return empty_grad
+
+    def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None:
         self.comm_stream.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self.comm_stream):
             new_grad = grad.clone()
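The rewritten hook simply dispatches on replication: replicated parameters go through reduce-scatter, while parameters owned by a single rank only stash their gradient; the hook then returns a hollow tensor because this hook type cannot return None. A generic sketch of that shape (free_storage and the is_replicated flag are modelled on the diff, not imported from ColossalAI):

    import torch

    def free_storage(t: torch.Tensor) -> None:
        # Drop the underlying storage so the returned placeholder costs no memory,
        # mirroring the empty_grad trick above.
        if t.storage().size() > 0:
            t.storage().resize_(0)

    def post_backward(param, grad, is_replicated, reduce_scatter_handler, save_grad):
        # In the real hook, is_replicated comes from param.colo_attr.is_replicated.
        if is_replicated:
            reduce_scatter_handler(param, grad)
        else:
            save_grad(param, grad)

        # "cheat PyTorch": hand back an emptied tensor instead of None
        empty_grad = torch.empty_like(grad)
        free_storage(empty_grad)
        return empty_grad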
@@ -334,9 +333,6 @@ class ShardedModelV2(nn.Module):
                 self._reduce_scatter_callback(param, new_grad)
             orig_grad_data.record_stream(self.comm_stream)
         torch.cuda.current_stream().wait_stream(self.comm_stream)
-        empty_grad = torch.empty_like(grad)
-        free_storage(empty_grad)
-        return empty_grad
 
     def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
         assert isinstance(reduced_grad,
@@ -345,21 +341,35 @@ class ShardedModelV2(nn.Module):
         if self.gradient_postdivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
-        # FIXME(ver217): remove the below line when impl eviction policy
+        self._save_grad(param, reduced_grad)
+
+    # FIXME(ver217): refactor the below line when impl eviction policy
+    def _save_grad(self, param: Parameter, grad: torch.Tensor):
+        # move gradient to cpu
         if param.colo_attr.offload_grad:
-            colo_model_data_move_to_cpu(reduced_grad)
+            colo_model_data_move_to_cpu(grad)
+
         if self.reuse_fp16_shard:
             # make parameters point to gradient
+
             assert param.colo_attr.saved_grad.is_null(
             ), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
-            param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad)
-            param.colo_attr.saved_grad.reset_payload(param.colo_attr.sharded_data_tensor.payload)
+
+            param.colo_attr.saved_grad.reset_payload(grad)
+            param.colo_attr.sharded_data_tensor.reset_payload(grad)    # release the memory of param
+
+            if param.colo_attr.is_replicated:
+                param.colo_attr.sharded_data_tensor.is_sharded = True
         else:
-            reduced_grad = cast_tensor_to_fp32(reduced_grad)
+            fp32_grad = cast_tensor_to_fp32(grad)
+
             if param.colo_attr.saved_grad.is_null():
-                param.colo_attr.saved_grad.reset_payload(reduced_grad)
+                param.colo_attr.saved_grad.reset_payload(fp32_grad)
             else:
-                param.colo_attr.saved_grad.payload.add_(reduced_grad.view_as(param.colo_attr.saved_grad.payload))
+                param.colo_attr.saved_grad.payload.add_(fp32_grad.view_as(param.colo_attr.saved_grad.payload))
+
         # keep saved_grad in HOLD state
         param.colo_attr.saved_grad.trans_state(TensorState.HOLD)
 
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
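The new _save_grad serves both paths: with reuse_fp16_shard, the fp16 parameter storage is re-pointed at the gradient (so gradient accumulation is explicitly disallowed), otherwise the gradient is upcast to fp32 and accumulated into saved_grad. A minimal sketch of the accumulation branch with plain tensors (illustration only):

    import torch

    def accumulate_grad_fp32(saved_grad, grad_fp16):
        # Upcast the incoming half-precision gradient and accumulate in fp32,
        # so repeated micro-batches do not lose precision.
        fp32_grad = grad_fp16.float()
        if saved_grad is None:          # first micro-batch: take ownership
            return fp32_grad
        saved_grad.add_(fp32_grad.view_as(saved_grad))
        return saved_grad

    # usage sketch: two micro-batches
    g = None
    g = accumulate_grad_fp32(g, torch.randn(4, dtype=torch.float16))
    g = accumulate_grad_fp32(g, torch.randn(4, dtype=torch.float16))
    print(g.dtype)    # torch.float32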
@@ -68,9 +68,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
         growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
         hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
-        keep_unsharded (bool, optional): if True, optimizer won't shard unsharded parameters.
-            In Zero-2, set keep_unsharded to False.
-            In Zero-3, set keep_unsharded to True.
         max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
         dp_process_group (Optional[ProcessGroup], optional): data paralle process group. Defaults to None.
         mp_process_group (Optional[ProcessGroup], optional): model paralle process group. Defaults to None.
@@ -91,7 +88,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  growth_interval: float = 1000,
                  hysteresis: float = 2,
                  max_scale: int = 2**32,
-                 keep_unsharded: bool = False,
                  dp_process_group: Optional[ProcessGroup] = None,
                  mp_process_group: Optional[ProcessGroup] = None) -> None:
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
@@ -125,10 +121,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
         self._logger = get_dist_logger("ShardedOptimizerV2")
 
-        assert not (keep_unsharded and self._should_move_fp32_shards_h2d), \
-            "Keeping unsharded parameters can't be used with hybrid OS placement right now."
-        self.keep_unshard = keep_unsharded
-
         # Store fp32 param shards
         self._register_master_weight()
@@ -139,6 +131,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         if self._use_memory_tracer:
             GLOBAL_MODEL_DATA_TRACER.register_optimizer(self)
 
+    @property
+    def loss_scale(self):
+        return self.grad_scaler.scale.item()
+
     def get_memory_usage(self) -> Tuple[int, int]:
         """ Get the memory usage of the optimizer. Including master_params (param fp32),
         momentum (``self.state[p]['exp_avg']``) variance (``self.state[p]['exp_avg_sq']``)
@@ -166,6 +162,22 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
 
         return cuda_use, cpu_use
 
+    def zero_grad(self, *args, **kwargs):
+        self._zero_grad()
+
+    def backward(self, loss: Tensor) -> None:
+        loss = self.loss_scale * loss
+        self.optim_state = OptimState.SCALED
+        self.model.backward(loss)
+
+    def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
+        self.model.backward_by_grad(tensor, grad)
+
+    def clip_grad_norm(self, model: nn.Module, max_norm: float):
+        if self.optim_state == OptimState.SCALED:
+            self._unscale_grads()
+        return super().clip_grad_norm(model, max_norm)
+
     def step(self, *args, **kwargs):
         self._prepare_grads()
         self._maybe_move_fp32_shards()
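These optimizer-facing methods only moved up in the file; their behaviour is unchanged: backward scales the loss by the current loss scale and marks the state as SCALED, and clip_grad_norm unscales before clipping. A generic sketch of that loss-scaling protocol (not the ShardedOptimizerV2 API):

    import torch

    class ScaledBackward:
        """Minimal loss-scaling sketch: scale on backward, unscale before clipping."""

        def __init__(self, params, loss_scale=2.0**5):
            self.params = list(params)
            self.loss_scale = loss_scale
            self.scaled = False

        def backward(self, loss: torch.Tensor) -> None:
            (self.loss_scale * loss).backward()
            self.scaled = True

        def unscale_(self) -> None:
            for p in self.params:
                if p.grad is not None:
                    p.grad.div_(self.loss_scale)
            self.scaled = False

        def clip_grad_norm(self, max_norm: float):
            if self.scaled:
                self.unscale_()
            return torch.nn.utils.clip_grad_norm_(self.params, max_norm)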
@@ -193,26 +205,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._logger.debug(
             f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory, {self.get_memory_usage()[1] / 1e6} MB CUDA Memory!",
             ranks=[0])
-        self._copy_master_param_to_param_fp16()
+        self._copy_master_model_to_model_fp16()
         return ret
 
-    def backward(self, loss: Tensor) -> None:
-        loss = self.loss_scale * loss
-        self.optim_state = OptimState.SCALED
-        self.model.backward(loss)
-
-    def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
-        self.model.backward_by_grad(tensor, grad)
-
-    def clip_grad_norm(self, model: nn.Module, max_norm: float):
-        if self.optim_state == OptimState.SCALED:
-            self._unscale_grads()
-        return super().clip_grad_norm(model, max_norm)
-
-    @property
-    def loss_scale(self):
-        return self.grad_scaler.scale.item()
-
     def _check_overflow(self):
         # clear previous overflow record
         self._found_overflow.fill_(0.0)
@@ -240,9 +235,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                         p.grad.data.div_(self.loss_scale)
         self.optim_state = OptimState.UNSCALED
 
-    def zero_grad(self, *args, **kwargs):
-        self._zero_grad()
-
     def _zero_grad(self, recover_data: bool = False):
         """zero grad and maybe recover fp16 params
         When `reuse_fp16_shard` is enabled,
@@ -262,13 +254,11 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # p.colo_attr.sharded_data_tensor stores grad now
                 # we have to recover fp16 param
                 reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr()
-                p.colo_attr.saved_grad.set_null()
                 if recover_data and reuse_fp16_shard:
-                    # We should write like this to trigger ForceFP32Paramter's half method
-                    p.data = self.master_params[p].payload
-                    p.colo_attr.sharded_data_tensor.reset_payload(
-                        colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
-                    p.colo_attr.remove_torch_payload()
+                    self._copy_master_param_to_param_fp16(p)
+                else:
+                    # release saved gradient
+                    p.colo_attr.saved_grad.set_null()
 
     def sync_grad(self):
         pass
@@ -278,14 +268,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         for group in self.optim.param_groups:
             for p in group['params']:
                 assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
-                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
-                if not is_param_sharded and not self.keep_unshard:
-                    # Please use keep_unsharded to control whether shard unsharded paramters
-                    # As we only store param shard, we shard it here
+                shard_flag = not p.colo_attr.sharded_data_tensor.is_sharded and p.colo_attr.is_replicated
+                if shard_flag:
+                    # we always shard replicated paramters
                     self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
                 self.master_params[p] = StatefulTensor(
                     cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload.to(self.device)))
-                if not is_param_sharded and not self.keep_unshard:
+                if shard_flag:
                     # In this branch, there's no need to shard param
                     # So we gather here
                     self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
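The master-weight registration loop above now shards a parameter only when it is both unsharded and replicated, takes the fp32 master copy from the (possibly just sharded) fp16 payload, and then gathers the fp16 tensor back. A stripped-down sketch of keeping fp32 masters for fp16 parameters (generic, without the StatefulTensor machinery):

    import torch

    def register_master_weights(params, device=torch.device("cpu")):
        # Keep an fp32 master copy of every fp16 parameter; the optimizer updates
        # the masters and copies them back into the fp16 weights after step().
        masters = {}
        for p in params:
            masters[p] = p.detach().to(device=device, dtype=torch.float32).clone()
        return masters

    # usage sketch
    fp16_params = [torch.nn.Parameter(torch.randn(4, dtype=torch.float16))]
    masters = register_master_weights(fp16_params)
    print(masters[fp16_params[0]].dtype)    # torch.float32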
@@ -328,31 +317,27 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             # Now p.data is sharded
             # So optimizer states are sharded naturally
 
-    def _copy_master_param_to_param_fp16(self):
+    def _copy_master_model_to_model_fp16(self):
         # Copy master param data (fp32) to payload of colo_attr (fp16)
         # TODO() improve efficiency by gathering tensors into a chunk and transfering
         # a chunk.
         for group in self.optim.param_groups:
             for p in group['params']:
-                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
-                if not is_param_sharded and not self.keep_unshard:
-                    # We use ZeRO-2 here
-                    # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
-                    # But we only have updated fp32 param shard here
-                    # So we first shard full fp16 param and copy fp32 param shard to it
-                    # Then we will gather them
-                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                # We have to use `copy_payload` instead of `reset_payload`
-                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
+                self._copy_master_param_to_param_fp16(p)
+
+    def _copy_master_param_to_param_fp16(self, p):
+        # flush gradient
+        p.colo_attr.saved_grad.set_null()
 
-                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.data = self.master_params[p].payload
-                p.colo_attr.sharded_data_tensor.reset_payload(
-                    colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
-                p.colo_attr.remove_torch_payload()
+        # TODO() optimize this line CPU (fp32) -> GPU (fp16)
+        p.data = self.master_params[p].payload
+        p.colo_attr.sharded_data_tensor.reset_payload(
+            colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
+        p.colo_attr.remove_torch_payload()
 
-                if not is_param_sharded and not self.keep_unshard:
-                    # We gather full fp16 param here
-                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+        if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
+            # We gather full fp16 param here
+            p.colo_attr.sharded_data_tensor.is_sharded = True    # since only gradient is sharded, we should set to True
+            self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
 
-                self.master_params[p].trans_state(TensorState.HOLD)
-                p.colo_attr.saved_grad.set_null()
+        self.master_params[p].trans_state(TensorState.HOLD)
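The per-parameter copy-back now flushes the saved gradient, clones the updated fp32 master into an fp16 payload on the parameter's device, and, for parameters kept unsharded in the ZeRO-2 style path, gathers the full fp16 tensor again. A small generic sketch of the fp32-master-to-fp16-weight copy (illustrative names, no sharding):

    import torch

    def copy_master_to_fp16(p: torch.nn.Parameter, master: torch.Tensor, device=None) -> None:
        # After the optimizer step, push the updated fp32 master weights back
        # into the fp16 parameter the model actually computes with.
        device = device or p.device
        p.data = master.to(device=device, dtype=torch.float16)

    # usage sketch
    p = torch.nn.Parameter(torch.zeros(4, dtype=torch.float16))
    master = torch.randn(4, dtype=torch.float32)
    copy_master_to_fp16(p, master)
    print(p.dtype)    # torch.float16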
@@ -71,9 +71,9 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
 
         # the parameters in moe experts is not replicated
         if 'experts' in name:
-            assert not param.is_replicated
+            assert not param.colo_attr.is_replicated
         else:
-            assert param.is_replicated
+            assert param.colo_attr.is_replicated
 
         if param.colo_attr.param_is_sharded:
             assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
@@ -36,7 +36,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
 
     # check whether parameters are identical in ddp
     for name, p in zero_model.named_parameters():
-        if not p.colo_attr.param_is_sharded and p.is_replicated:
+        if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
             assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload)
 
     model = MoeModel().half()
@@ -48,8 +48,13 @@ def _run_step(model, optimizer, data, label, criterion, grad_handler):
 
 @parameterize("cpu_offload", [True])
 @parameterize("use_cpuadam", [True])    # We do not use Hybrid Adam right now, since it has a little bug
+@parameterize("reuse_fp16_shard", [True, False])
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio=0.0):
+def _run_test_sharded_optim_v2(cpu_offload,
+                               shard_strategy_class,
+                               use_cpuadam,
+                               reuse_fp16_shard,
+                               gpu_margin_mem_ratio=0.0):
     shard_strategy = shard_strategy_class()
     if use_cpuadam and cpu_offload is False:
         return
@@ -63,17 +68,15 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
                               shard_param=True):
         zero_model = MoeModel()
 
-    zero_model = ShardedModelV2(
-        zero_model,
-        shard_strategy,
-        offload_config=dict(device='cpu') if cpu_offload else None,
-        use_memory_tracer=gpu_margin_mem_ratio > 0.0,
-        reuse_fp16_shard=use_cpuadam,
-    )
+    zero_model = ShardedModelV2(zero_model,
+                                shard_strategy,
+                                offload_config=dict(device='cpu') if cpu_offload else None,
+                                use_memory_tracer=gpu_margin_mem_ratio > 0.0,
+                                reuse_fp16_shard=reuse_fp16_shard)
 
     # check whether parameters are identical in ddp
     for name, p in zero_model.named_parameters():
-        if not p.colo_attr.param_is_sharded and p.is_replicated:
+        if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
             assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload.to(get_current_device()))
 
     model = MoeModel().half()
@@ -88,8 +91,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
                                            sharded_optim,
                                            cpu_offload=cpu_offload,
                                            initial_scale=2**5,
-                                           gpu_margin_mem_ratio=gpu_margin_mem_ratio,
-                                           keep_unsharded=True)
+                                           gpu_margin_mem_ratio=gpu_margin_mem_ratio)
 
     amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False)
     apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
@@ -93,7 +93,7 @@ def check_grads_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
     for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
         # zero_grad = zero_p.grad.clone().to(p.device)
-        if zero_p.colo_attr.param_is_sharded:
+        if zero_p.colo_attr.is_replicated:
            zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
            chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
            if rank >= len(chunks):
@@ -102,8 +102,9 @@ def check_grads_padding(model, zero_model, loose=False):
             if zero_grad.size(0) > grad.size(0):
                 zero_grad = zero_grad[:grad.size(0)]
         else:
-            grad = p.grad
+            zero_grad = zero_p.colo_attr.saved_grad.payload
+            grad = p.grad.to(zero_grad.dtype)
 
         assert grad.dtype == zero_grad.dtype
         assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}'
@@ -134,7 +135,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
             if zero_p.size(0) > p.size(0):
                 zero_p = zero_p[:p.size(0)]
         else:
-            zero_p = zero_p.colo_attr.sharded_data_tensor.payload
+            zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device)
 
         assert p.dtype == zero_p.dtype
         assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'