From cf519dac6a5799b8f314aac6f510e2a98d3af9c6 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Wed, 20 Nov 2024 16:36:37 +0800
Subject: [PATCH] [optim] hotfix adam load (#6146)

* [optim] hotfix adam load

* [checkpointio] fix optimizer async io

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* [checkpointio] update test

* [checkpointio] update test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../booster/plugin/low_level_zero_plugin.py |   2 +-
 colossalai/nn/optimizer/cpu_adam.py         |   8 +
 colossalai/testing/comparison.py            |   6 +-
 colossalai/utils/safetensors.py             | 141 +++++++++++-------
 .../test_safetensors_async_io.py            |  64 +++++---
 5 files changed, 142 insertions(+), 79 deletions(-)

diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 12ffe5fe5..761947344 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -142,7 +142,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             from colossalai.utils.safetensors import save_nested
 
             f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
-            save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
+            save_nested(f_writer, state_dict)
             self.async_writers.append(f_writer)
         else:
             save_state_dict(state_dict, checkpoint, use_safetensors=False)
diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index 68fb582e5..f10945763 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -81,6 +81,14 @@ class CPUAdam(NVMeOptimizer):
         # if you find yourself stuck here, make sure that you install colossalai with BUILD_EXT=1 specification
         self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
 
+    def load_state_dict(self, state_dict):
+        super().load_state_dict(state_dict)
+        for group in self.param_groups:
+            for p in group["params"]:
+                state = self.state[p]
+                if "step" in state and isinstance(state["step"], torch.Tensor):
+                    state["step"] = int(state["step"].item())
+
     def torch_adam_update(
         self,
         data,
diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py
index 4cbb01163..8f9cce246 100644
--- a/colossalai/testing/comparison.py
+++ b/colossalai/testing/comparison.py
@@ -1,4 +1,4 @@
-from typing import Any, List, OrderedDict, Tuple
+from typing import Any, List, OrderedDict
 
 import torch
 import torch.distributed as dist
@@ -78,9 +78,7 @@ def check_state_dict_equal(
                 v1 = v1.to(v2.dtype)
             assert_close_loose(v1, v2)
         else:
-            if isinstance(v1, Tuple) and not isinstance(v2, Tuple):
-                v2 = tuple(v2)
-            assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}"
+            assert v1 == v2, f"{v1} not equals to {v2}"
 
 
 def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py
index ad7d3be77..8b8cb627f 100644
--- a/colossalai/utils/safetensors.py
+++ b/colossalai/utils/safetensors.py
@@ -1,6 +1,5 @@
 # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
 import json
-import warnings
 from dataclasses import asdict, dataclass
 from typing import Dict, List, Optional, Tuple
 
@@ -12,6 +11,26 @@ try:
 except ModuleNotFoundError:
     raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
 _TYPES_INV = {v: k for k, v in _TYPES.items()}
+import io
+
+from torch.distributed.distributed_c10d import _pickler, _unpickler
+
+
+def _object_to_tensor(obj, device):
+    f = io.BytesIO()
+    _pickler(f).dump(obj)
+    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())  # type: ignore[attr-defined]
+    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
+    # Otherwise, it will cause a 100X slowdown.
+    # See: https://github.com/pytorch/pytorch/issues/65696
+    byte_tensor = torch.ByteTensor(byte_storage).to(device)
+    return byte_tensor
+
+
+def _tensor_to_object(tensor, tensor_size):
+    tensor = tensor.cpu()
+    buf = tensor.numpy().tobytes()[:tensor_size]
+    return _unpickler(io.BytesIO(buf)).load()
 
 
 @dataclass
@@ -28,49 +47,68 @@ class PreparedData:
     offset: int
 
 
-def flatten_dict(nested_dict, parent_key="", separator="^"):
-    """
-    Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator.
-
-    nested_dict: The input nested dictionary.
-    parent_key: The parent key currently being processed.
-    separator: The separator used to join keys, default is '_', but can be customized to another symbol. :return: A flattened dictionary."
-    """
-    items = []
-    for k, v in nested_dict.items():
-        new_key = f"{parent_key}{separator}{k}" if parent_key else str(k)
-        if isinstance(v, dict):
-            items.extend(flatten_dict(v, new_key, separator).items())
-        else:
-            v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v
-            items.append((new_key, v))
-
-    return dict(items)
-
-
-def unflatten_dict(flattened_dict, separator="^"):
-    """
-    Restore a flattened dictionary back to a multi-level nested dictionary.
-
-    flattened_dict: The flattened dictionary.
-    separator: The separator used during flattening, default is '_', but can be customized to another symbol. :return: The restored nested dictionary.
-    """
-    nested_dict = {}
-    for key, value in flattened_dict.items():
-        keys = key.split(separator)
-        try:
-            keys[0] = int(keys[0])
-        except ValueError:
-            warnings.warn(f"{key[0]} can't convert to integer")
-        d = nested_dict
-        for part in keys[:-1]:
-            if part not in d:
-                d[part] = {}
-            d = d[part]
-        assert isinstance(value, torch.Tensor)
-        d[keys[-1]] = value
-
-    return nested_dict
+def _cast_to_tensor(obj):
+    if isinstance(obj, torch.Tensor):
+        return obj
+    return _object_to_tensor(obj, "cpu")
+
+
+def _cast_to_object(tensor: torch.Tensor):
+    return _tensor_to_object(tensor, tensor.numel() * tensor.element_size())
+
+
+def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[dict, Optional[dict]]:
+    flat_dict = {}
+    non_tensor_keys = []
+    if "state" in state_dict:
+        # 3-level dict
+        states = state_dict["state"]
+    else:
+        # 2-level dict, usually for optimizer state dict shard
+        states = state_dict
+
+    for idx, d in states.items():
+        for k, v in d.items():
+            nested_key = f"state{seperator}{idx}{seperator}{k}"
+            if not isinstance(v, torch.Tensor):
+                non_tensor_keys.append(nested_key)
+            flat_dict[nested_key] = _cast_to_tensor(v)
+    if "param_groups" in state_dict:
+        flat_dict["param_groups"] = _cast_to_tensor(state_dict["param_groups"])
+        non_tensor_keys.append("param_groups")
+    if len(non_tensor_keys) > 0:
+        metadata = {"non_tensor_keys": json.dumps(non_tensor_keys)}
+    else:
+        metadata = None
+    return flat_dict, metadata
+
+
+def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None, seperator: str = "."):
+    state_dict = {}
+    if metadata is not None:
+        non_tensor_keys = json.loads(metadata["non_tensor_keys"])
+    else:
+        non_tensor_keys = []
+    flat_dict = {k: _cast_to_object(v) if k in non_tensor_keys else v for k, v in flat_dict.items()}
+    if "param_groups" in flat_dict:
+        # 3-level dict
+        state_dict["param_groups"] = flat_dict.pop("param_groups")
+        state_dict["state"] = {}
+        states = state_dict["state"]
+    else:
+        # 2-level dict, usually for optimizer state dict shard
+        states = state_dict
+
+    for k, v in flat_dict.items():
+        parts = k.split(seperator)
+        assert len(parts) == 3 and parts[0] == "state"
+        idx = int(parts[1])
+        key = parts[2]
+        if idx not in states:
+            states[idx] = {}
+        states[idx][key] = v
+
+    return state_dict
 
 
 def prepare(
@@ -124,10 +162,8 @@ def save(
     f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
 
 
-def save_nested(
-    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
-) -> None:
-    flatten_data = flatten_dict(state_dict)
+def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
+    flatten_data, metadata = _flatten_optim_state_dict(state_dict)
     save(f_writer, flatten_data, metadata)
 
 
@@ -154,10 +190,5 @@ def load_flat(checkpoint_path):
     with safe_open(checkpoint_path, framework="pt") as f:
         metadata = f.metadata()
     state_dict_load = load_file(checkpoint_path)
-    state_dict = unflatten_dict(state_dict_load)
-    if metadata is None:
-        return state_dict
-    metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items()))
-    combined_state_dict = {"state": state_dict}
-    combined_state_dict.update(metadata)
-    return combined_state_dict
+    state_dict = _unflatten_optim_state_dict(state_dict_load, metadata)
+    return state_dict
diff --git a/tests/test_checkpoint_io/test_safetensors_async_io.py b/tests/test_checkpoint_io/test_safetensors_async_io.py
index 31c69e961..521ec10bd 100644
--- a/tests/test_checkpoint_io/test_safetensors_async_io.py
+++ b/tests/test_checkpoint_io/test_safetensors_async_io.py
@@ -1,9 +1,9 @@
 import tempfile
-from copy import deepcopy
 
 import torch
+from safetensors.torch import load_file
 
-from colossalai.utils.safetensors import load_flat, save_nested
+from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested
 
 try:
     from tensornvme.async_file_io import AsyncFileWriter
@@ -11,17 +11,29 @@ except ModuleNotFoundError:
     raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
 
 from colossalai.testing import check_state_dict_equal
+from colossalai.utils import get_current_device
 
 
 def test_save_load():
     with tempfile.TemporaryDirectory() as tempdir:
         optimizer_state_dict = {
-            0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-            1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-            2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-        }
-        # group_dict = {"param_groups": [0, 1, 2]}
-        group_dict = {
+            "state": {
+                0: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+                1: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+                2: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+            },
             "param_groups": [
                 {
                     "lr": 0.001,
@@ -94,22 +106,26 @@ def test_save_load():
                         61,
                     ],
                 }
-            ]
+            ],
         }
-        metadata = deepcopy(group_dict)
+
         optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
         f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
-
-        save_nested(f_writer, optimizer_state_dict, metadata)
+        save_nested(f_writer, optimizer_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         f_writer.fp.close()
-
         load_state_dict = load_flat(optimizer_saved_path)
-        state_dict = load_state_dict["state"]
-        group = {"param_groups": load_state_dict["param_groups"]}
-        check_state_dict_equal(optimizer_state_dict, state_dict)
-        check_state_dict_equal(group_dict, group)
+        check_state_dict_equal(load_state_dict, optimizer_state_dict)
+
+        optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
+        f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread")
+        save_nested(f_writer, optimizer_state_dict["state"])
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+        load_state_dict_shard = load_flat(optimizer_shard_saved_path)
+        check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])
 
         model_state_dict = {
             "module.weight0": torch.rand((1024, 1024)),
@@ -118,10 +134,20 @@ def test_save_load():
         }
         model_saved_path = f"{tempdir}/save_model.safetensors"
         f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
-        save_nested(f_writer, model_state_dict)
+        save(f_writer, model_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         f_writer.fp.close()
+        load_state_dict = load_file(model_saved_path)
+        check_state_dict_equal(model_state_dict, load_state_dict)
 
-        load_state_dict = load_flat(model_saved_path)
+        model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
+        model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
+        model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
+        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
+        move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+        load_state_dict = load_file(model_saved_path)
         check_state_dict_equal(model_state_dict, load_state_dict)
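
Usage sketch: with this patch applied, save_nested derives the flattening metadata from the optimizer state dict itself, and load_flat restores the nested {"state": ..., "param_groups": ...} layout. A minimal round trip is shown below, assuming ColossalAI with this patch and tensornvme are installed; the optimizer construction, output file name, and n_entries value are illustrative only (the latter mirrors the test above):

    import torch
    from tensornvme.async_file_io import AsyncFileWriter

    from colossalai.utils.safetensors import load_flat, save_nested

    # A tiny Adam instance, just to produce a realistic optimizer state dict.
    param = torch.nn.Parameter(torch.rand(4, 4))
    optimizer = torch.optim.Adam([param], lr=1e-3)
    param.grad = torch.rand_like(param)
    optimizer.step()  # populates the step / exp_avg / exp_avg_sq state

    state_dict = optimizer.state_dict()  # {"state": {...}, "param_groups": [...]}
    f_writer = AsyncFileWriter(fp=open("optimizer.safetensors", "wb"), n_entries=191, backend="pthread")
    save_nested(f_writer, state_dict)  # tensors stored under "state.{idx}.{key}"; non-tensor values pickled, their keys recorded in metadata
    f_writer.sync_before_step()
    f_writer.synchronize()
    f_writer.fp.close()

    loaded = load_flat("optimizer.safetensors")  # nested state dict restored, param_groups included

The same round trip works for an optimizer state dict shard (a bare {idx: {...}} mapping without "state"/"param_groups"), which is the 2-level case handled by _flatten_optim_state_dict and exercised by the shard section of the test.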