diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 8f2232393..96d5902e8 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -80,9 +80,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper): tp_process_group: Optional[ProcessGroup] = None, # if using tp forced_dtype: Optional[torch.dtype] = None): - # TODO: - # 1. state_dict for checkpoint IO - super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) self._dtype = self.optim.param_groups[0]['params'][0].dtype self._logger = get_dist_logger() @@ -528,9 +525,12 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for k, v in state.items(): if isinstance(v, torch.Tensor) and k != 'step': working_param = self._param_store.master_to_working_param[id(param)] - gather_tensor = [torch.zeros_like(v) for _ in range(self._world_size)] - dist.all_gather(gather_tensor, v, group=self.dp_pg) - param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(working_param) + gather_tensor = [ + torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size) + ] + dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg) + param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as( + working_param).cpu() zero_state[param][k] = param_state states_dict = self._pack_state(zero_state) @@ -553,7 +553,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper): if padding_size > 0: v = torch.nn.functional.pad(v, [0, padding_size]) v_list = v.split(v.numel() // self._world_size) - zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach() + device = 'cpu' if self._cpu_offload else 'cuda' + zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].to(device).detach() self.optim.load_state_dict(zero_state_dict) zero_state_dict = dict() @@ -585,9 +586,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper): for k, v in states.items(): if isinstance(v, 
torch.Tensor) and k != 'step': - state_tensor = [torch.zeros_like(v) for _ in range(self._world_size)] - dist.all_gather(state_tensor, v, group=self.dp_pg) - state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(working_param) + state_tensor = [torch.zeros(v.shape, device='cuda', dtype=v.dtype) for _ in range(self._world_size)] + dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg) + state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as( + working_param).cpu() current_block_size += state_tensor.numel() current_block[k] = state_tensor diff --git a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py index a94e8d42c..3faa395b5 100644 --- a/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_low_level_zero_checkpoint_io.py @@ -16,19 +16,21 @@ from colossalai.testing import ( ) +# stage 1 and 2 process the optimizer/model the same way, +# so testing only stage 2 is sufficient @clear_cache_before_run() @parameterize('stage', [2]) @parameterize('shard', [True, False]) -def check_low_level_zero_checkpointIO(stage: int, shard: bool): - plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32) +@parameterize('offload', [False, True]) +def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool): + plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload) booster = Booster(plugin=plugin) model = resnet18() criterion = lambda x: x.mean() optimizer = HybridAdam((model.parameters()), lr=0.001) model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) - x = torch.randn(4, 3, 224, 224) - x = x.to('cuda') + x = torch.randn(1, 3, 224, 224, device='cuda') output = model(x) loss = criterion(output) booster.backward(loss, optimizer) @@ -50,15 +52,17 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool):
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False) booster.load_optimizer(new_optimizer, optimizer_ckpt_path) - check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False) + check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False) def run_dist(rank, world_size, port): colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host='localhost') check_low_level_zero_checkpointIO() + torch.cuda.empty_cache() @rerun_if_address_is_in_use() +@clear_cache_before_run() def test_low_level_zero_checkpointIO(): spawn(run_dist, 2) diff --git a/tests/test_zero/test_low_level/test_zero_ckpt.py b/tests/test_zero/test_low_level/test_zero_ckpt.py index 23356fe71..ab811c6b4 100644 --- a/tests/test_zero/test_low_level/test_zero_ckpt.py +++ b/tests/test_zero/test_low_level/test_zero_ckpt.py @@ -37,7 +37,7 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32): atol = 4e-3 a = a.detach().to(dtype) - b = b.detach().to(dtype) + b = b.detach().to(dtype).to(a.device) assert_close(a, b, rtol=rtol, atol=atol)