[zero] polish sharded optimizer v2 (#490)

pull/488/head
ver217 2022-03-22 15:53:48 +08:00 committed by GitHub
parent 62b0a8d644
commit a9ecb4b244
1 changed file with 16 additions and 13 deletions


@@ -110,19 +110,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
 
     def step(self, *args, **kwargs):
-        if self._should_move_fp32_shards_h2d:
-            self._should_move_fp32_shards_h2d = False
-            available_cuda_margin_mem = self.model.cuda_margin_space * self.gpu_margin_mem_ratio
-            fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
-            fp32_shards_used_cuda_margin_mem = 0
-            for group in self.optim.param_groups:
-                for p in group['params']:
-                    shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
-                    if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
-                        self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
-                        p.grad.data = p.grad.data.to(torch.cuda.current_device())
-                        p.col_attr.offload_fp32_grad = False
-                        fp32_shards_used_cuda_margin_mem += shard_mem
+        self._maybe_move_fp32_shards()
         # unscale grads if scaled
         if self.optim_state == OptimState.SCALED:
@@ -223,3 +211,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def sync_grad(self):
         pass
+
+    def _maybe_move_fp32_shards(self):
+        if self._should_move_fp32_shards_h2d:
+            self._should_move_fp32_shards_h2d = False
+            available_cuda_margin_mem = self.model.cuda_margin_space * self.gpu_margin_mem_ratio
+            fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
+            fp32_shards_used_cuda_margin_mem = 0
+            for group in self.optim.param_groups:
+                for p in group['params']:
+                    shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
+                    if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
+                        self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
+                        p.grad.data = p.grad.data.to(torch.cuda.current_device())
+                        p.col_attr.offload_fp32_grad = False
+                        fp32_shards_used_cuda_margin_mem += shard_mem
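
For context, the helper introduced in this hunk implements a simple greedy budget: the spare CUDA memory left after forward/backward (`self.model.cuda_margin_space`) is scaled by `gpu_margin_mem_ratio`, divided by the number of fp32 shards the wrapped optimizer keeps per parameter, and fp32 master shards are then moved host-to-device in parameter-group order until the next shard would exceed that budget. Moving a shard also moves its gradient to the GPU and clears `offload_fp32_grad`, presumably so the subsequent optimizer update for that shard runs on the device rather than on the offloaded CPU copy. The standalone sketch below reproduces only the budgeting logic on plain tensors; `move_shards_within_budget`, `margin_bytes`, and `ratio` are illustrative names, not part of the ColossalAI API.

import torch

def move_shards_within_budget(master_shards, margin_bytes, ratio, num_fp32_shards_per_param=1):
    """Greedily move fp32 shards to the GPU while the running total stays under the margin budget."""
    budget = margin_bytes * ratio / num_fp32_shards_per_param
    used = 0
    # Fall back to CPU so the sketch also runs on machines without CUDA.
    device = torch.device('cuda', torch.cuda.current_device()) if torch.cuda.is_available() else torch.device('cpu')
    moved = []
    for i, shard in enumerate(master_shards):
        shard_mem = shard.numel() * shard.element_size()
        if used + shard_mem < budget:
            master_shards[i] = shard.to(device)  # H2D copy of the fp32 master weights
            used += shard_mem
            moved.append(i)
    return moved

# Toy usage: three 1024-element fp32 shards (4 KiB each) against a 10 KiB margin.
shards = [torch.zeros(1024, dtype=torch.float32) for _ in range(3)]
print(move_shards_within_budget(shards, margin_bytes=10 * 1024, ratio=1.0))  # -> [0, 1]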