mirror of https://github.com/hpcaitech/ColossalAI
[zero] polish sharded optimizer v2 (#490)
parent 62b0a8d644
commit a9ecb4b244
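This commit extracts the host-to-device move of the fp32 master shards out of step() into a new private helper, _maybe_move_fp32_shards(). The first hunk deletes the inline block; the second re-adds it, unchanged, as the helper's body.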
@@ -110,19 +110,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
 
     def step(self, *args, **kwargs):
-        if self._should_move_fp32_shards_h2d:
-            self._should_move_fp32_shards_h2d = False
-            available_cuda_margin_mem = self.model.cuda_margin_space * self.gpu_margin_mem_ratio
-            fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
-            fp32_shards_used_cuda_margin_mem = 0
-            for group in self.optim.param_groups:
-                for p in group['params']:
-                    shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
-                    if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
-                        self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
-                        p.grad.data = p.grad.data.to(torch.cuda.current_device())
-                        p.col_attr.offload_fp32_grad = False
-                        fp32_shards_used_cuda_margin_mem += shard_mem
+        self._maybe_move_fp32_shards()
         # unscale grads if scaled
         if self.optim_state == OptimState.SCALED:
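The block removed above is a greedy first-fit budget over the CUDA margin memory: each fp32 master shard (and its gradient) moves to the GPU only while the running byte total stays under the budget. Below is a minimal standalone sketch of that budgeting pattern, with hypothetical names (move_shards_within_margin, margin_bytes) rather than the ColossalAI API, and skipping the gradient move and the offload_fp32_grad flag:

import torch

def move_shards_within_margin(shards, margin_bytes, margin_ratio, num_fp32_shards_per_param):
    # Greedy first-fit: promote shards host-to-device while they fit in the budget.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    budget = margin_bytes * margin_ratio / num_fp32_shards_per_param
    used = 0
    out = []
    for shard in shards:
        shard_mem = shard.numel() * shard.element_size()  # bytes this fp32 shard occupies
        if used + shard_mem < budget:
            out.append(shard.to(device))  # fits: move to GPU
            used += shard_mem
        else:
            out.append(shard)  # does not fit: stays offloaded where it is
    return out

# Three 1 MiB fp32 shards against a 2.5 MiB budget: the first two move, the third stays.
shards = [torch.zeros(256 * 1024, dtype=torch.float32) for _ in range(3)]
moved = move_shards_within_margin(shards, margin_bytes=2.5 * 1024**2,
                                  margin_ratio=1.0, num_fp32_shards_per_param=1)

The division by num_fp32_shards_per_param in the real code presumably reserves the same headroom for every fp32 buffer the optimizer keeps per parameter (master weight plus, e.g., Adam's two moment estimates), not just the master copy.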
@@ -223,3 +211,18 @@
 
     def sync_grad(self):
         pass
+
+    def _maybe_move_fp32_shards(self):
+        if self._should_move_fp32_shards_h2d:
+            self._should_move_fp32_shards_h2d = False
+            available_cuda_margin_mem = self.model.cuda_margin_space * self.gpu_margin_mem_ratio
+            fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
+            fp32_shards_used_cuda_margin_mem = 0
+            for group in self.optim.param_groups:
+                for p in group['params']:
+                    shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
+                    if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
+                        self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
+                        p.grad.data = p.grad.data.to(torch.cuda.current_device())
+                        p.col_attr.offload_fp32_grad = False
+                        fp32_shards_used_cuda_margin_mem += shard_mem
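Taken together, the two hunks are a pure extract-method refactor: step() shrinks to a single _maybe_move_fp32_shards() call ahead of the existing grad-unscaling logic, and the helper's body is line-for-line the block that was removed. Note also that the loop visits parameters in param-group order, so once the margin budget tightens, earlier groups keep priority for GPU residency.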