mirror of https://github.com/hpcaitech/ColossalAI
[zero] polish sharded optimizer v2 (#490)
parent
62b0a8d644
commit
a9ecb4b244
|
@ -110,19 +110,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
if self._should_move_fp32_shards_h2d:
|
||||
self._should_move_fp32_shards_h2d = False
|
||||
available_cuda_margin_mem = self.model.cuda_margin_space * self.gpu_margin_mem_ratio
|
||||
fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
|
||||
fp32_shards_used_cuda_margin_mem = 0
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
|
||||
if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
|
||||
self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
|
||||
p.grad.data = p.grad.data.to(torch.cuda.current_device())
|
||||
p.col_attr.offload_fp32_grad = False
|
||||
fp32_shards_used_cuda_margin_mem += shard_mem
|
||||
self._maybe_move_fp32_shards()
|
||||
|
||||
# unscale grads if scaled
|
||||
if self.optim_state == OptimState.SCALED:
|
||||
|
@ -223,3 +211,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
|
||||
def sync_grad(self):
|
||||
pass
|
||||
|
||||
def _maybe_move_fp32_shards(self):
|
||||
if self._should_move_fp32_shards_h2d:
|
||||
self._should_move_fp32_shards_h2d = False
|
||||
available_cuda_margin_mem = self.model.cuda_margin_space * self.gpu_margin_mem_ratio
|
||||
fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
|
||||
fp32_shards_used_cuda_margin_mem = 0
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
|
||||
if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
|
||||
self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
|
||||
p.grad.data = p.grad.data.to(torch.cuda.current_device())
|
||||
p.col_attr.offload_fp32_grad = False
|
||||
fp32_shards_used_cuda_margin_mem += shard_mem
|
||||
|
|
Loading…
Reference in New Issue