From a9ecb4b2441afda8bd9ac3890990b624bb9d28f2 Mon Sep 17 00:00:00 2001
From: ver217
Date: Tue, 22 Mar 2022 15:53:48 +0800
Subject: [PATCH] [zero] polish sharded optimizer v2 (#490)

---
 .../zero/sharded_optim/sharded_optim_v2.py | 29 ++++++++++---------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 8ccec731d..2109d4499 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -110,19 +110,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
 
     def step(self, *args, **kwargs):
-        if self._should_move_fp32_shards_h2d:
-            self._should_move_fp32_shards_h2d = False
-            available_cuda_margin_mem = self.model.cuda_margin_space * self.gpu_margin_mem_ratio
-            fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
-            fp32_shards_used_cuda_margin_mem = 0
-            for group in self.optim.param_groups:
-                for p in group['params']:
-                    shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
-                    if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
-                        self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
-                        p.grad.data = p.grad.data.to(torch.cuda.current_device())
-                        p.col_attr.offload_fp32_grad = False
-                        fp32_shards_used_cuda_margin_mem += shard_mem
+        self._maybe_move_fp32_shards()
 
         # unscale grads if scaled
         if self.optim_state == OptimState.SCALED:
@@ -223,3 +211,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
 
     def sync_grad(self):
         pass
+
+    def _maybe_move_fp32_shards(self):
+        if self._should_move_fp32_shards_h2d:
+            self._should_move_fp32_shards_h2d = False
+            available_cuda_margin_mem = self.model.cuda_margin_space * self.gpu_margin_mem_ratio
+            fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
+            fp32_shards_used_cuda_margin_mem = 0
+            for group in self.optim.param_groups:
+                for p in group['params']:
+                    shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
+                    if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
+                        self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
+                        p.grad.data = p.grad.data.to(torch.cuda.current_device())
+                        p.col_attr.offload_fp32_grad = False
+                        fp32_shards_used_cuda_margin_mem += shard_mem
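
Reviewer note: the patch only moves the existing host-to-device logic out of step()
into a private helper, _maybe_move_fp32_shards(), without changing behaviour. For
reference, the margin-budget idea that the helper encapsulates can be sketched in
isolation. The names below (move_shards_within_margin, shards, margin_space) are
illustrative stand-ins, not ColossalAI attributes, and the sketch assumes a CUDA
device is available.

import torch

def move_shards_within_margin(shards, margin_space, gpu_margin_mem_ratio,
                              num_fp32_shards_per_param):
    """Greedily move fp32 master shards to the current GPU until the
    per-shard share of the CUDA margin memory is used up."""
    budget = margin_space * gpu_margin_mem_ratio / num_fp32_shards_per_param
    used = 0
    moved = []
    for shard in shards:
        shard_mem = shard.numel() * shard.element_size()
        if used + shard_mem < budget:
            # Fits in the remaining margin: place this shard on the GPU.
            moved.append(shard.to(torch.cuda.current_device()))
            used += shard_mem
        else:
            # Budget exhausted: leave the shard where it is (offloaded).
            moved.append(shard)
    return moved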