From a9b8300d54f842c4ae840552e99f2ffadffcc172 Mon Sep 17 00:00:00 2001 From: HELSON Date: Mon, 11 Apr 2022 13:38:51 +0800 Subject: [PATCH] [zero] improve adaptability for not-shard parameters (#708) * adapt post grad hooks for not-shard parameters * adapt optimizer for not-shard parameters * offload gradients for not-replicated parameters --- colossalai/nn/layer/moe/utils.py | 2 +- colossalai/nn/optimizer/cpu_adam.py | 5 +- colossalai/zero/init_ctx/init_context.py | 12 +- .../zero/sharded_model/sharded_model_v2.py | 64 ++++++----- .../zero/sharded_optim/sharded_optim_v2.py | 105 ++++++++---------- tests/test_moe/test_moe_zero_init.py | 4 +- tests/test_moe/test_moe_zero_model.py | 2 +- tests/test_moe/test_moe_zero_optim.py | 24 ++-- tests/test_zero_data_parallel/common.py | 7 +- 9 files changed, 114 insertions(+), 111 deletions(-) diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py index 3a1258bd1..fd985146a 100644 --- a/colossalai/nn/layer/moe/utils.py +++ b/colossalai/nn/layer/moe/utils.py @@ -8,7 +8,7 @@ from .experts import FFNExperts, TPExperts class ForceFP32Parameter(torch.nn.Parameter): def half(self, memory_format=None): - return self + return self.data class NormalNoiseGenerator: diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index 475e615ea..084b0cc0b 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -142,6 +142,7 @@ class CPUAdam(torch.optim.Optimizer): beta1, beta2 = group['betas'] if target_device.type == 'cpu': + assert p.data.numel() == p.grad.data.numel(), "parameter and gradient should have the same size" assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu" assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu" self.cpu_adam_op.adam_update(self.opt_id, state['step'], group['lr'], beta1, beta2, group['eps'], @@ -151,8 +152,8 @@ class CPUAdam(torch.optim.Optimizer): assert state['exp_avg'].device.type == 'cuda', "exp_avg should stay on cuda" assert state['exp_avg_sq'].device.type == 'cuda', "exp_avg should stay on cuda" - bias_correction1 = 1 - beta1 ** state['step'] - bias_correction2 = 1 - beta2 ** state['step'] + bias_correction1 = 1 - beta1**state['step'] + bias_correction2 = 1 - beta2**state['step'] # adam on cuda self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'], diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py index 1a2fce2da..0284c92f3 100644 --- a/colossalai/zero/init_ctx/init_context.py +++ b/colossalai/zero/init_ctx/init_context.py @@ -213,7 +213,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0] for param in self.param_list: assert hasattr(param, 'colo_attr') - if not param.colo_attr.param_is_sharded and param.is_replicated: + if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated: dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group) param.colo_attr.remove_torch_payload() @@ -239,9 +239,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): self.model_numel_tensor += param.numel() - # mark whether the param is replicated - param.is_replicated = self.is_replicated - # convert parameters to half param_half = half_fn(param) param.data = param_half @@ -261,6 +258,13 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], 
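A brief aside on the `cpu_adam.py` hunk above: the new assert only guards that a parameter and its gradient hold the same number of elements (which must now also hold for parameters kept unsharded), and the touched `bias_correction` lines are the standard bias-corrected Adam update that the CUDA branch delegates to `torch_adam_update`. A minimal stand-alone sketch of that update (`adam_step` is a hypothetical helper, not the patched API):

```python
# Illustrative sketch only: the bias-corrected Adam step the CUDA branch computes.
import torch

def adam_step(p, grad, exp_avg, exp_avg_sq, lr, beta1, beta2, eps, step):
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)               # first moment
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # second moment
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
    p.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)

p = torch.zeros(4)
adam_step(p, torch.ones(4), torch.zeros(4), torch.zeros(4),
          lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, step=1)
```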
self.dp_process_group) param.data = param.colo_attr.sharded_data_tensor.payload # set param.data to payload + # mark whether the param is replicated + param.colo_attr.is_replicated = self.is_replicated + + # mark whether the param should keep not sharded + # if True, the param is used as Zero stage 2 + param.colo_attr.keep_not_shard = not self.shard_param + self.param_list.append(param) # We must cast buffers diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index bd3752482..403aaadfa 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -123,7 +123,7 @@ class ShardedModelV2(nn.Module): ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group) ] register_ophooks_recursively(self.module, self._ophook_list) - self.param_hook_mgr = BaseParamHookMgr(self.sharded_params) + self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters())) self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook) self.fp32_reduce_scatter = fp32_reduce_scatter @@ -177,8 +177,8 @@ class ShardedModelV2(nn.Module): self.logger.error(f'dump memort tracer collected infomation to a {filename}', ranks=[0]) if gpc.get_global_rank() == 0: with open(filename, 'w+') as f: - f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device())/1e9} GB\n') - f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device())/1e9} GB\n') + f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n') + f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n') f.write('CUDA model data (GB)\n') f.write(str(self._memstats_collector.model_data_list('cuda', 'GB'))) f.write('\n') @@ -254,10 +254,6 @@ class ShardedModelV2(nn.Module): torch.cuda.current_stream().synchronize() self.reducer.free() - # all reduce gradients for unsharded parameters - reduce_list = [p for p in self.unshard_params if p.is_replicated] - bucket_allreduce(reduce_list, self.process_group) - # 3. 
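Two notes on the `sharded_model_v2.py` changes that begin above: the hook manager is now fed `list(self.module.parameters())` instead of only the sharded parameters, so every gradient (sharded, kept-unsharded, or non-replicated) passes through `_grad_post_backward_hook`, which is why the separate `bucket_allreduce` pass for unsharded parameters can be deleted. A simplified stand-in for that per-parameter hook registration, using the public `Tensor.register_hook` rather than the `BaseParamHookMgr` internals:

```python
# Simplified sketch: fire a callback once each parameter's gradient is produced.
import torch
import torch.nn as nn

def register_grad_hooks(module: nn.Module, fn):
    handles = []
    for p in module.parameters():
        if p.requires_grad:
            # the hook receives the gradient; the default arg binds the parameter
            handles.append(p.register_hook(lambda grad, p=p: fn(p, grad)))
    return handles

model = nn.Linear(4, 2)
register_grad_hooks(model, lambda p, g: print(tuple(p.shape), float(g.norm())))
model(torch.randn(3, 4)).sum().backward()
```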
shard tensors not dealed in the zero hook tensor_list = [] for p in self.sharded_params: @@ -279,15 +275,6 @@ class ShardedModelV2(nn.Module): if not self._require_backward_grad_sync: continue - # move unsharded param grad to saved_grad - if not p.colo_attr.param_is_sharded: - if p.colo_attr.offload_grad: - colo_model_data_move_to_cpu(p.grad) - if p.colo_attr.saved_grad.is_null(): - p.colo_attr.saved_grad.reset_payload(p.grad.data) - else: - p.colo_attr.saved_grad.payload.add_(p.grad.data) - p.grad = None @torch.no_grad() @@ -316,6 +303,18 @@ class ShardedModelV2(nn.Module): assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients' if not self._require_backward_grad_sync: return + + if param.colo_attr.is_replicated: + self._reduce_scatter_handler(param, grad) + else: + self._save_grad(param, grad) + + # used to cheat Pytorch, since we can't return None + empty_grad = torch.empty_like(grad) + free_storage(empty_grad) + return empty_grad + + def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None: self.comm_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self.comm_stream): new_grad = grad.clone() @@ -334,9 +333,6 @@ class ShardedModelV2(nn.Module): self._reduce_scatter_callback(param, new_grad) orig_grad_data.record_stream(self.comm_stream) torch.cuda.current_stream().wait_stream(self.comm_stream) - empty_grad = torch.empty_like(grad) - free_storage(empty_grad) - return empty_grad def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None: assert isinstance(reduced_grad, @@ -345,21 +341,35 @@ class ShardedModelV2(nn.Module): if self.gradient_postdivide_factor > 1: # Average grad by world_size for consistency with PyTorch DDP. reduced_grad.data.div_(self.gradient_postdivide_factor) - # FIXME(ver217): remove the below line when impl eviction policy + self._save_grad(param, reduced_grad) + + # FIXME(ver217): refactor the below line when impl eviction policy + def _save_grad(self, param: Parameter, grad: torch.Tensor): + # move gradient to cpu if param.colo_attr.offload_grad: - colo_model_data_move_to_cpu(reduced_grad) + colo_model_data_move_to_cpu(grad) + if self.reuse_fp16_shard: + # make parameters point to gradient + assert param.colo_attr.saved_grad.is_null( ), 'Gradien accumulation is not supported when reuse_fp16_shard=True' - param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad) - param.colo_attr.sharded_data_tensor.is_sharded = True - param.colo_attr.saved_grad.reset_payload(param.colo_attr.sharded_data_tensor.payload) + + param.colo_attr.saved_grad.reset_payload(grad) + param.colo_attr.sharded_data_tensor.reset_payload(grad) # release the memory of param + + if param.colo_attr.is_replicated: + param.colo_attr.sharded_data_tensor.is_sharded = True else: - reduced_grad = cast_tensor_to_fp32(reduced_grad) + + fp32_grad = cast_tensor_to_fp32(grad) + if param.colo_attr.saved_grad.is_null(): - param.colo_attr.saved_grad.reset_payload(reduced_grad) + param.colo_attr.saved_grad.reset_payload(fp32_grad) else: - param.colo_attr.saved_grad.payload.add_(reduced_grad.view_as(param.colo_attr.saved_grad.payload)) + param.colo_attr.saved_grad.payload.add_(fp32_grad.view_as(param.colo_attr.saved_grad.payload)) + + # keep saved_grad in HOLD state param.colo_attr.saved_grad.trans_state(TensorState.HOLD) def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]': diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py 
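Before the optimizer-side diff below, a condensed sketch of the gradient path the model-side hunks above introduce: replicated parameters go through the reduce-scatter handler, non-replicated ones (e.g. MoE experts) save their gradient directly, and `_save_grad` either lets the fp16 shard alias the gradient (`reuse_fp16_shard`) or accumulates in fp32. `reduce_scatter_across_dp` is a placeholder, not a real ColossalAI API, and the ad-hoc attributes stand in for `colo_attr`.

```python
import torch

def reduce_scatter_across_dp(grad):
    # placeholder: the real handler clones the grad, optionally casts to fp32,
    # and reduce-scatters it over the data-parallel group
    return grad

def post_backward_hook(param, grad, reuse_fp16_shard=False):
    if getattr(param, 'is_replicated', True):
        grad = reduce_scatter_across_dp(grad)
    save_grad(param, grad, reuse_fp16_shard)
    # hand autograd an empty tensor with freed storage, mirroring the
    # "cheat PyTorch" trick in the hook above
    empty = torch.empty_like(grad)
    empty.storage().resize_(0)
    return empty

def save_grad(param, grad, reuse_fp16_shard):
    if reuse_fp16_shard:
        # the fp16 shard aliases the gradient, so accumulation is not possible
        assert getattr(param, 'saved_grad', None) is None
        param.saved_grad = grad
    else:
        fp32_grad = grad.float()
        if getattr(param, 'saved_grad', None) is None:
            param.saved_grad = fp32_grad
        else:
            param.saved_grad.add_(fp32_grad.view_as(param.saved_grad))

w = torch.nn.Parameter(torch.randn(4))
post_backward_hook(w, torch.ones(4))
post_backward_hook(w, torch.ones(4))   # a second micro-batch accumulates in fp32
print(w.saved_grad)                    # tensor([2., 2., 2., 2.])
```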
b/colossalai/zero/sharded_optim/sharded_optim_v2.py index bd708dad3..4196befcd 100644 --- a/colossalai/zero/sharded_optim/sharded_optim_v2.py +++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py @@ -68,9 +68,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer): backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5. growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000. hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2. - keep_unsharded (bool, optional): if True, optimizer won't shard unsharded parameters. - In Zero-2, set keep_unsharded to False. - In Zero-3, set keep_unsharded to True. max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32. dp_process_group (Optional[ProcessGroup], optional): data paralle process group. Defaults to None. mp_process_group (Optional[ProcessGroup], optional): model paralle process group. Defaults to None. @@ -91,7 +88,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer): growth_interval: float = 1000, hysteresis: float = 2, max_scale: int = 2**32, - keep_unsharded: bool = False, dp_process_group: Optional[ProcessGroup] = None, mp_process_group: Optional[ProcessGroup] = None) -> None: assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel' @@ -125,10 +121,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer): self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device()) self._logger = get_dist_logger("ShardedOptimizerV2") - assert not (keep_unsharded and self._should_move_fp32_shards_h2d), \ - "Keeping unsharded parameters can't be used with hybrid OS placement right now." - self.keep_unshard = keep_unsharded - # Store fp32 param shards self._register_master_weight() @@ -139,6 +131,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer): if self._use_memory_tracer: GLOBAL_MODEL_DATA_TRACER.register_optimizer(self) + @property + def loss_scale(self): + return self.grad_scaler.scale.item() + def get_memory_usage(self) -> Tuple[int, int]: """ Get the memory usage of the optimizer. 
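The `loss_scale` property grouped here with `backward` and `clip_grad_norm` implements the usual scaled-backward / unscale-before-clip handshake. A minimal sketch with a constant scale standing in for `DynamicGradScaler`:

```python
import torch

scale = 2.0**5                               # DynamicGradScaler would adjust this

def scaled_backward(loss):
    (loss * scale).backward()                # gradients come out pre-multiplied

def unscale_grads(params):
    for p in params:
        if p.grad is not None:
            p.grad.div_(scale)               # undo scaling before clipping/stepping

w = torch.nn.Parameter(torch.ones(3))
scaled_backward((w * 2).sum())
unscale_grads([w])
print(w.grad)                                # tensor([2., 2., 2.])
```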
Including master_params (param fp32), momentum (``self.state[p]['exp_avg']``) variance (``self.state[p]['exp_avg_sq']``) @@ -166,6 +162,22 @@ class ShardedOptimizerV2(ColossalaiOptimizer): return cuda_use, cpu_use + def zero_grad(self, *args, **kwargs): + self._zero_grad() + + def backward(self, loss: Tensor) -> None: + loss = self.loss_scale * loss + self.optim_state = OptimState.SCALED + self.model.backward(loss) + + def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None: + self.model.backward_by_grad(tensor, grad) + + def clip_grad_norm(self, model: nn.Module, max_norm: float): + if self.optim_state == OptimState.SCALED: + self._unscale_grads() + return super().clip_grad_norm(model, max_norm) + def step(self, *args, **kwargs): self._prepare_grads() self._maybe_move_fp32_shards() @@ -193,26 +205,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer): self._logger.debug( f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory, {self.get_memory_usage()[1] / 1e6} MB CUDA Memory!", ranks=[0]) - self._copy_master_param_to_param_fp16() + self._copy_master_model_to_model_fp16() return ret - def backward(self, loss: Tensor) -> None: - loss = self.loss_scale * loss - self.optim_state = OptimState.SCALED - self.model.backward(loss) - - def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None: - self.model.backward_by_grad(tensor, grad) - - def clip_grad_norm(self, model: nn.Module, max_norm: float): - if self.optim_state == OptimState.SCALED: - self._unscale_grads() - return super().clip_grad_norm(model, max_norm) - - @property - def loss_scale(self): - return self.grad_scaler.scale.item() - def _check_overflow(self): # clear previous overflow record self._found_overflow.fill_(0.0) @@ -240,9 +235,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer): p.grad.data.div_(self.loss_scale) self.optim_state = OptimState.UNSCALED - def zero_grad(self, *args, **kwargs): - self._zero_grad() - def _zero_grad(self, recover_data: bool = False): """zero grad and maybe recover fp16 params When `reuse_fp16_shard` is enabled, @@ -262,13 +254,11 @@ class ShardedOptimizerV2(ColossalaiOptimizer): # p.colo_attr.sharded_data_tensor stores grad now # we have to recover fp16 param reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr() - p.colo_attr.saved_grad.set_null() if recover_data and reuse_fp16_shard: - # We should write like this to trigger ForceFP32Paramter's half method - p.data = self.master_params[p].payload - p.colo_attr.sharded_data_tensor.reset_payload( - colo_model_tensor_clone(p.half(), torch.cuda.current_device())) - p.colo_attr.remove_torch_payload() + self._copy_master_param_to_param_fp16(p) + else: + # release saved gradient + p.colo_attr.saved_grad.set_null() def sync_grad(self): pass @@ -278,14 +268,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer): for group in self.optim.param_groups: for p in group['params']: assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam' - is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded - if not is_param_sharded and not self.keep_unshard: - # Please use keep_unsharded to control whether shard unsharded paramters - # As we only store param shard, we shard it here + shard_flag = not p.colo_attr.sharded_data_tensor.is_sharded and p.colo_attr.is_replicated + if shard_flag: + # we always shard replicated paramters self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group) self.master_params[p] = StatefulTensor( 
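The `shard_flag` introduced in `_register_master_weight` above boils down to: shard a parameter for master-weight storage only if it is replicated and not already sharded, then gather it back so the fp16 working copy stays whole. A toy sketch, where the `shard`/`gather` callables and the ad-hoc attributes stand in for the shard strategy and `colo_attr`:

```python
import torch

def make_param(numel, replicated=True):
    p = torch.nn.Parameter(torch.randn(numel).half(), requires_grad=False)
    p.is_replicated, p.is_sharded = replicated, False   # toy stand-ins for colo_attr
    return p

def register_master_weights(params, shard, gather):
    master = {}
    for p in params:
        shard_flag = (not p.is_sharded) and p.is_replicated
        if shard_flag:
            shard(p)                         # keep only this rank's slice of p.data
        master[p] = p.data.detach().float()  # fp32 master copy (a shard when sharded)
        if shard_flag:
            gather(p)                        # restore the full fp16 working copy
    return master

params = [make_param(4), make_param(4, replicated=False)]   # second one: an MoE expert
master = register_master_weights(params, shard=lambda p: None, gather=lambda p: None)
```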
cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload.to(self.device))) - if not is_param_sharded and not self.keep_unshard: + if shard_flag: # In this branch, there's no need to shard param # So we gather here self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group) @@ -328,31 +317,27 @@ class ShardedOptimizerV2(ColossalaiOptimizer): # Now p.data is sharded # So optimizer states are sharded naturally - def _copy_master_param_to_param_fp16(self): + def _copy_master_model_to_model_fp16(self): # Copy master param data (fp32) to payload of colo_attr (fp16) # TODO() improve efficiency by gathering tensors into a chunk and transfering # a chunk. for group in self.optim.param_groups: for p in group['params']: - is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded - if not is_param_sharded and not self.keep_unshard: - # We use ZeRO-2 here - # The `p.colo_attr.sharded_data_tensor` saves full fp16 param - # But we only have updated fp32 param shard here - # So we first shard full fp16 param and copy fp32 param shard to it - # Then we will gather them - self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group) - # We have to use `copy_payload` instead of `reset_payload` - # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16 + self._copy_master_param_to_param_fp16(p) - # TODO() optimize this line CPU (fp32) -> GPU (fp16) - p.colo_attr.sharded_data_tensor.reset_payload( - colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device)) - p.colo_attr.remove_torch_payload() + def _copy_master_param_to_param_fp16(self, p): + # flush gradient + p.colo_attr.saved_grad.set_null() - if not is_param_sharded and not self.keep_unshard: - # We gather full fp16 param here - self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group) + # TODO() optimize this line CPU (fp32) -> GPU (fp16) + p.data = self.master_params[p].payload + p.colo_attr.sharded_data_tensor.reset_payload( + colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device)) + p.colo_attr.remove_torch_payload() - self.master_params[p].trans_state(TensorState.HOLD) - p.colo_attr.saved_grad.set_null() + if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated: + # We gather full fp16 param here + p.colo_attr.sharded_data_tensor.is_sharded = True # since only gradient is sharded, we should set to True + self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group) + + self.master_params[p].trans_state(TensorState.HOLD) diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py index 45dc061c7..309692285 100644 --- a/tests/test_moe/test_moe_zero_init.py +++ b/tests/test_moe/test_moe_zero_init.py @@ -71,9 +71,9 @@ def run_moe_zero_init(init_device_type, shard_strategy_class): # the parameters in moe experts is not replicated if 'experts' in name: - assert not param.is_replicated + assert not param.colo_attr.is_replicated else: - assert param.is_replicated + assert param.colo_attr.is_replicated if param.colo_attr.param_is_sharded: assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \ diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py index 336117f7a..d90294adf 100644 --- a/tests/test_moe/test_moe_zero_model.py +++ b/tests/test_moe/test_moe_zero_model.py @@ -36,7 +36,7 @@ def run_model_test(enable_autocast, shard_strategy_class): # check whether parameters are identical in ddp for name, p in 
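The per-parameter copy-back above moves the updated fp32 master (possibly just this rank's shard) into the fp16 payload and, for replicated parameters that are kept unsharded, re-gathers the fp16 shards into a full tensor. A rough sketch of that data movement, with `gather_across_dp` as a placeholder for the shard strategy's gather:

```python
import torch

def gather_across_dp(shard):
    # placeholder for ShardStrategy.gather over the data-parallel group
    return shard

def copy_master_to_fp16(master, replicated=True, keep_not_shard=False, device='cpu'):
    payload = master.to(device).half()       # CPU fp32 master -> fp16 working payload
    if keep_not_shard and replicated:
        # only the gradient was sharded, so rebuild the full fp16 parameter
        payload = gather_across_dp(payload)
    return payload

w32 = torch.full((4,), 1.5)                  # updated fp32 master weights
w16 = copy_master_to_fp16(w32)
print(w16.dtype, w16)                        # torch.float16, tensor of 1.5s
```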
zero_model.named_parameters(): - if not p.colo_attr.param_is_sharded and p.is_replicated: + if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated: assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload) model = MoeModel().half() diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index aa1ac57bc..8348e093c 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -48,8 +48,13 @@ def _run_step(model, optimizer, data, label, criterion, grad_handler): @parameterize("cpu_offload", [True]) @parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug +@parameterize("reuse_fp16_shard", [True, False]) @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio=0.0): +def _run_test_sharded_optim_v2(cpu_offload, + shard_strategy_class, + use_cpuadam, + reuse_fp16_shard, + gpu_margin_mem_ratio=0.0): shard_strategy = shard_strategy_class() if use_cpuadam and cpu_offload is False: return @@ -63,17 +68,15 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g shard_param=True): zero_model = MoeModel() - zero_model = ShardedModelV2( - zero_model, - shard_strategy, - offload_config=dict(device='cpu') if cpu_offload else None, - use_memory_tracer=gpu_margin_mem_ratio > 0.0, - reuse_fp16_shard=use_cpuadam, - ) + zero_model = ShardedModelV2(zero_model, + shard_strategy, + offload_config=dict(device='cpu') if cpu_offload else None, + use_memory_tracer=gpu_margin_mem_ratio > 0.0, + reuse_fp16_shard=reuse_fp16_shard) # check whether parameters are identical in ddp for name, p in zero_model.named_parameters(): - if not p.colo_attr.param_is_sharded and p.is_replicated: + if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated: assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload.to(get_current_device())) model = MoeModel().half() @@ -88,8 +91,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g sharded_optim, cpu_offload=cpu_offload, initial_scale=2**5, - gpu_margin_mem_ratio=gpu_margin_mem_ratio, - keep_unsharded=True) + gpu_margin_mem_ratio=gpu_margin_mem_ratio) amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False) apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config) diff --git a/tests/test_zero_data_parallel/common.py b/tests/test_zero_data_parallel/common.py index a35ed0060..993fed98e 100644 --- a/tests/test_zero_data_parallel/common.py +++ b/tests/test_zero_data_parallel/common.py @@ -93,7 +93,7 @@ def check_grads_padding(model, zero_model, loose=False): rank = dist.get_rank() for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()): # zero_grad = zero_p.grad.clone().to(p.device) - if zero_p.colo_attr.param_is_sharded: + if zero_p.colo_attr.is_replicated: zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device) chunks = torch.flatten(p.grad).chunk(dist.get_world_size()) if rank >= len(chunks): @@ -102,8 +102,9 @@ def check_grads_padding(model, zero_model, loose=False): if zero_grad.size(0) > grad.size(0): zero_grad = zero_grad[:grad.size(0)] else: - grad = p.grad zero_grad = zero_p.colo_attr.saved_grad.payload + grad = p.grad.to(zero_grad.dtype) + assert grad.dtype == zero_grad.dtype assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}' @@ -134,7 +135,7 @@ def 
check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
             if zero_p.size(0) > p.size(0):
                 zero_p = zero_p[:p.size(0)]
         else:
-            zero_p = zero_p.colo_attr.sharded_data_tensor.payload
+            zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device)
 
         assert p.dtype == zero_p.dtype
         assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
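Taken together, the updated API is exercised roughly as in the revised tests: `keep_unsharded` is gone from `ShardedOptimizerV2`, `reuse_fp16_shard` is chosen when wrapping the model, and replication is read from `param.colo_attr.is_replicated`. A hedged end-to-end sketch follows; it assumes a launched ColossalAI distributed context, `MyModel`, `loader`, and `criterion` are placeholders, and import paths and constructor defaults may differ between ColossalAI versions.

```python
import torch
from colossalai.nn.optimizer import CPUAdam
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2

shard_strategy = TensorShardStrategy()
with ZeroInitContext(target_device=torch.device('cpu'),
                     shard_strategy=shard_strategy,
                     shard_param=True):
    model = MyModel()                         # placeholder model (MoeModel in the tests)

model = ShardedModelV2(model,
                       shard_strategy,
                       offload_config=dict(device='cpu'),
                       reuse_fp16_shard=True)

optim = CPUAdam(model.parameters(), lr=1e-3)
optim = ShardedOptimizerV2(model, optim,
                           cpu_offload=True,
                           initial_scale=2**5)    # note: no keep_unsharded argument

criterion = torch.nn.CrossEntropyLoss()           # placeholder loss
for data, label in loader:                        # placeholder dataloader
    optim.zero_grad()
    loss = criterion(model(data), label)
    optim.backward(loss)                          # scales the loss internally
    optim.step()
```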