diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 091a5b274..401ff988d 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -288,6 +288,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             fp32_shards_used_cuda_margin_mem = 0
             for group in self.optim.param_groups:
                 for p in group['params']:
+                    if p.colo_attr.saved_grad.is_null():
+                        continue
                     shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size()
                     if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
                         colo_model_data_tensor_move_inline(self.master_params[p], torch.cuda.current_device())
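
For context, the added `is_null()` guard skips parameters that have no saved gradient for the current step, so their fp32 master shards neither count against nor get moved into the CUDA margin memory. The following is a minimal, self-contained sketch of that control flow under simplified assumptions; `FakeGrad`, `FakeShard`, `place_shards`, and the byte accounting are hypothetical stand-ins, not the actual ColossalAI API.

```python
from dataclasses import dataclass


@dataclass
class FakeGrad:
    # Hypothetical stand-in for p.colo_attr.saved_grad
    null: bool

    def is_null(self) -> bool:
        return self.null


@dataclass
class FakeShard:
    # Hypothetical stand-in for self.master_params[p].payload
    numel: int
    element_size: int = 4  # fp32 bytes per element

    @property
    def nbytes(self) -> int:
        return self.numel * self.element_size


def place_shards(params, available_margin_bytes):
    """Pick which fp32 master shards fit in the CUDA margin memory.

    Params whose gradient is null are skipped, mirroring the guard added
    by the diff: they consume no margin memory and are not moved to CUDA.
    """
    used = 0
    on_cuda = []
    for name, (grad, shard) in params.items():
        if grad.is_null():           # the new early-continue guard
            continue
        if used + shard.nbytes < available_margin_bytes:
            on_cuda.append(name)     # stands in for colo_model_data_tensor_move_inline(...)
            used += shard.nbytes
    return on_cuda


if __name__ == "__main__":
    params = {
        "w1": (FakeGrad(null=False), FakeShard(numel=1024)),
        "w2": (FakeGrad(null=True), FakeShard(numel=1 << 20)),  # skipped: null grad
        "w3": (FakeGrad(null=False), FakeShard(numel=2048)),
    }
    print(place_shards(params, available_margin_bytes=16 * 1024))  # ['w1', 'w3']
```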