fix move fp32 shards (#1604)

pull/1605/head
ver217 2022-09-16 17:33:16 +08:00 committed by GitHub
parent eac1b79371
commit c9e8ce67b8
1 changed file with 2 additions and 0 deletions


@@ -288,6 +288,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         fp32_shards_used_cuda_margin_mem = 0
         for group in self.optim.param_groups:
             for p in group['params']:
+                if p.colo_attr.saved_grad.is_null():
+                    continue
                 shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size()
                 if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
                     colo_model_data_tensor_move_inline(self.master_params[p], torch.cuda.current_device())
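
For context: the added guard skips parameters whose sharded gradient is null this step (e.g. they took no part in the backward pass), so their fp32 master shards neither count against nor consume the available CUDA margin. Below is a minimal, self-contained sketch of the same pattern; `FakeShard`, `maybe_move_shards`, and the byte accounting are illustrative stand-ins, not the ColossalAI API.

    # Sketch of the margin-based shard-moving loop with the null-grad guard
    # from this commit. FakeShard is a stand-in for a master param shard.
    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class FakeShard:
        numel: int
        element_size: int = 4          # fp32 bytes per element
        grad: Optional[object] = None  # None models saved_grad.is_null()

        @property
        def nbytes(self) -> int:
            return self.numel * self.element_size

    def maybe_move_shards(shards: List[FakeShard], margin_bytes: int) -> List[FakeShard]:
        """Return the shards that fit in the CUDA margin, skipping grad-less ones."""
        moved, used = [], 0
        for shard in shards:
            if shard.grad is None:  # the fix: no gradient this step means no
                continue            # optimizer update, so don't move the shard
            if used + shard.nbytes < margin_bytes:
                moved.append(shard)  # stands in for colo_model_data_tensor_move_inline
                used += shard.nbytes
        return moved

    shards = [FakeShard(1024, grad=object()), FakeShard(4096), FakeShard(512, grad=object())]
    print(len(maybe_move_shards(shards, margin_bytes=8192)))  # -> 2; grad-less shard skipped

Without the guard, the grad-less shard's bytes would still be counted toward (and possibly moved into) the margin, wasting CUDA memory on a master shard that will not be updated this step.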