From 17e73e62cc5cd057d8bf22ac182329e6c965de86 Mon Sep 17 00:00:00 2001
From: HELSON
Date: Sun, 3 Apr 2022 22:02:11 +0800
Subject: [PATCH] [hotfix] fix bugs for unsharded parameters when restore data (#664)

---
 .../zero/sharded_optim/sharded_optim_v2.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 181f72931..b9252eb94 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -132,7 +132,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Store fp32 param shards
         self._register_master_weight()
 
-        self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
+        self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!",
                            ranks=[0])
 
         self._use_memory_tracer = self.model.use_memory_tracer
@@ -185,13 +185,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._point_param_fp16_to_master_param()
 
         self._logger.debug(
-            f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
+            f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory, {self.get_memory_usage()[1] / 1e6} MB CUDA Memory!",
             ranks=[0])
 
         ret = self.optim.step(*args, **kwargs)
 
         self._logger.debug(
-            f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
+            f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory, {self.get_memory_usage()[1] / 1e6} MB CUDA Memory!",
             ranks=[0])
         self._copy_master_param_to_param_fp16()
         return ret
@@ -264,8 +264,14 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr()
                 p.colo_attr.saved_grad.set_null()
                 if recover_data and reuse_fp16_shard:
+                    # We should write like this to trigger ForceFP32Paramter's half method
+                    p.data = self.master_params[p].payload
                     p.colo_attr.sharded_data_tensor.reset_payload(
-                        colo_model_tensor_clone(self.master_params[p].payload.half(), torch.cuda.current_device()))
+                        colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
+
+                    if not p.colo_attr.param_is_sharded:
+                        # FIXME(hhc): add hook for unsharded parameters
+                        p.data = p.colo_attr.sharded_data_tensor.payload
 
     def sync_grad(self):
         pass
@@ -281,7 +287,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                     # As we only store param shard, we shard it here
                     self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
                 self.master_params[p] = StatefulTensor(
-                    cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
+                    cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload.to(self.device)))
                 if not is_param_sharded and not self.keep_unshard:
                     # In this branch, there's no need to shard param
                     # So we gather here
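
Why the new `p.data = self.master_params[p].payload` line matters: calling `half()` on the parameter itself lets a parameter subclass that overrides `half()` (the `ForceFP32Paramter` the comment refers to) control the conversion, whereas calling `half()` directly on the raw master payload tensor bypasses that override. Below is a minimal standalone sketch of that dispatch difference; `ForceFP32Param`, its `half()` behaviour, and the tensor names are hypothetical stand-ins, not ColossalAI's actual classes.

import torch
from torch.nn import Parameter


class ForceFP32Param(Parameter):
    """Hypothetical stand-in: a parameter whose half() must be routed through
    the class (e.g. to hand back an fp16 copy while the parameter stays fp32)."""

    def half(self):
        # Return an fp16 clone instead of converting the parameter in place.
        return self.data.detach().clone().half()


# Stand-in for the optimizer's fp32 master copy (master_params[p].payload in the patch).
master_payload = torch.randn(4, dtype=torch.float32)

p = ForceFP32Param(torch.zeros(4, dtype=torch.float16), requires_grad=False)

# As in the patch: write the fp32 data through p.data first, so that the
# following half() call dispatches to ForceFP32Param.half ...
p.data = master_payload
fp16_shard = p.half()

# ... whereas half() on the raw payload tensor would bypass the subclass entirely.
bypassed = master_payload.half()

assert fp16_shard.dtype == torch.float16
assert p.data.dtype == torch.float32    # the parameter itself keeps fp32 data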