
[optim] hotfix adam load (#6146)

* [optim] hotfix adam load

* [checkpointio] fix optimizer async io

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [checkpointio] update test

* [checkpointio] update test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Hongxin Liu committed 2 days ago (via GitHub)
commit cf519dac6a
Changed files (5):
1. colossalai/booster/plugin/low_level_zero_plugin.py (2 lines changed)
2. colossalai/nn/optimizer/cpu_adam.py (8 lines changed)
3. colossalai/testing/comparison.py (6 lines changed)
4. colossalai/utils/safetensors.py (141 lines changed)
5. tests/test_checkpoint_io/test_safetensors_async_io.py (64 lines changed)

colossalai/booster/plugin/low_level_zero_plugin.py (2 lines changed)

@@ -142,7 +142,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             from colossalai.utils.safetensors import save_nested

             f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
-            save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
+            save_nested(f_writer, state_dict)
             self.async_writers.append(f_writer)
         else:
             save_state_dict(state_dict, checkpoint, use_safetensors=False)
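Note: the call site now passes the whole optimizer state dict; the "state" / "param_groups" split and the safetensors metadata are derived inside save_nested rather than at the call site. A minimal sketch of the shape the new call expects (values illustrative, not from this commit):

    import torch

    # Hypothetical state dict; in the plugin it comes from the optimizer itself.
    state_dict = {
        "state": {0: {"step": 10, "exp_avg": torch.zeros(4), "exp_avg_sq": torch.zeros(4)}},
        "param_groups": [{"lr": 1e-3, "params": [0]}],
    }
    # Old call: save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
    # New call: save_nested(f_writer, state_dict)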

colossalai/nn/optimizer/cpu_adam.py (8 lines changed)

@@ -81,6 +81,14 @@ class CPUAdam(NVMeOptimizer):
         # if you find yourself stuck here, make sure that you install colossalai with BUILD_EXT=1 specification
         self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)

+    def load_state_dict(self, state_dict):
+        super().load_state_dict(state_dict)
+        for group in self.param_groups:
+            for p in group["params"]:
+                state = self.state[p]
+                if "step" in state and isinstance(state["step"], torch.Tensor):
+                    state["step"] = int(state["step"].item())
+
     def torch_adam_update(
         self,
         data,
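Note: this is the actual hotfix. The safetensors format can only store tensors, so Adam's "step" counter round-trips through a checkpoint as a 0-d tensor, while the fused C++ CPU Adam kernel expects a plain Python int. A standalone sketch of the normalization the new load_state_dict hook performs:

    import torch

    # After loading, "step" may come back as a 0-d tensor rather than an int.
    state = {"step": torch.tensor(1.0), "exp_avg": torch.zeros(4)}
    if "step" in state and isinstance(state["step"], torch.Tensor):
        state["step"] = int(state["step"].item())
    assert state["step"] == 1 and isinstance(state["step"], int)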

colossalai/testing/comparison.py (6 lines changed)

@@ -1,4 +1,4 @@
-from typing import Any, List, OrderedDict, Tuple
+from typing import Any, List, OrderedDict

 import torch
 import torch.distributed as dist
@@ -78,9 +78,7 @@ def check_state_dict_equal(
                     v1 = v1.to(v2.dtype)
                 assert_close_loose(v1, v2)
             else:
-                if isinstance(v1, Tuple) and not isinstance(v2, Tuple):
-                    v2 = tuple(v2)
-                assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}"
+                assert v1 == v2, f"{v1} not equals to {v2}"


 def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):

colossalai/utils/safetensors.py (141 lines changed)

@@ -1,6 +1,5 @@
 # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
 import json
-import warnings
 from dataclasses import asdict, dataclass
 from typing import Dict, List, Optional, Tuple
@@ -12,6 +11,26 @@ try:
 except ModuleNotFoundError:
     raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")

 _TYPES_INV = {v: k for k, v in _TYPES.items()}
+import io
+
+from torch.distributed.distributed_c10d import _pickler, _unpickler
+
+
+def _object_to_tensor(obj, device):
+    f = io.BytesIO()
+    _pickler(f).dump(obj)
+    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())  # type: ignore[attr-defined]
+    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
+    # Otherwise, it will cause a 100X slowdown.
+    # See: https://github.com/pytorch/pytorch/issues/65696
+    byte_tensor = torch.ByteTensor(byte_storage).to(device)
+    return byte_tensor
+
+
+def _tensor_to_object(tensor, tensor_size):
+    tensor = tensor.cpu()
+    buf = tensor.numpy().tobytes()[:tensor_size]
+    return _unpickler(io.BytesIO(buf)).load()
+

 @dataclass
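These two helpers mirror the object collectives in torch.distributed: pickle an arbitrary Python object into a uint8 tensor so it can ride through a tensor-only format, and unpickle it on the way back. A quick round-trip sketch using the functions as added above:

    import torch

    param_groups = [{"lr": 0.001, "params": [0, 1, 2]}]
    byte_tensor = _object_to_tensor(param_groups, "cpu")
    assert byte_tensor.dtype == torch.uint8
    # For a uint8 tensor, numel() equals the byte count, which is exactly
    # how _cast_to_object (below) computes tensor_size.
    assert _tensor_to_object(byte_tensor, byte_tensor.numel()) == param_groups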
@@ -28,49 +47,68 @@ class PreparedData:
     offset: int


-def flatten_dict(nested_dict, parent_key="", separator="^"):
-    """
-    Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator.
-
-    nested_dict: The input nested dictionary.
-    parent_key: The parent key currently being processed.
-    separator: The separator used to join keys, default is '_', but can be customized to another symbol.
-    :return: A flattened dictionary.
-    """
-    items = []
-    for k, v in nested_dict.items():
-        new_key = f"{parent_key}{separator}{k}" if parent_key else str(k)
-        if isinstance(v, dict):
-            items.extend(flatten_dict(v, new_key, separator).items())
-        else:
-            v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v
-            items.append((new_key, v))
-
-    return dict(items)
-
-
-def unflatten_dict(flattened_dict, separator="^"):
-    """
-    Restore a flattened dictionary back to a multi-level nested dictionary.
-
-    flattened_dict: The flattened dictionary.
-    separator: The separator used during flattening, default is '_', but can be customized to another symbol.
-    :return: The restored nested dictionary.
-    """
-    nested_dict = {}
-    for key, value in flattened_dict.items():
-        keys = key.split(separator)
-        try:
-            keys[0] = int(keys[0])
-        except ValueError:
-            warnings.warn(f"{key[0]} can't convert to integer")
-        d = nested_dict
-        for part in keys[:-1]:
-            if part not in d:
-                d[part] = {}
-            d = d[part]
-        assert isinstance(value, torch.Tensor)
-        d[keys[-1]] = value
-
-    return nested_dict
+def _cast_to_tensor(obj):
+    if isinstance(obj, torch.Tensor):
+        return obj
+    return _object_to_tensor(obj, "cpu")
+
+
+def _cast_to_object(tensor: torch.Tensor):
+    return _tensor_to_object(tensor, tensor.numel() * tensor.element_size())
+
+
+def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[dict, Optional[dict]]:
+    flat_dict = {}
+    non_tensor_keys = []
+    if "state" in state_dict:
+        # 3-level dict
+        states = state_dict["state"]
+    else:
+        # 2-level dict, usually for optimizer state dict shard
+        states = state_dict
+
+    for idx, d in states.items():
+        for k, v in d.items():
+            nested_key = f"state{seperator}{idx}{seperator}{k}"
+            if not isinstance(v, torch.Tensor):
+                non_tensor_keys.append(nested_key)
+            flat_dict[nested_key] = _cast_to_tensor(v)
+    if "param_groups" in state_dict:
+        flat_dict["param_groups"] = _cast_to_tensor(state_dict["param_groups"])
+        non_tensor_keys.append("param_groups")
+    if len(non_tensor_keys) > 0:
+        metadata = {"non_tensor_keys": non_tensor_keys}
+    else:
+        metadata = None
+    return flat_dict, metadata
+
+
+def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None, seperator: str = "."):
+    state_dict = {}
+    if metadata is not None:
+        non_tensor_keys = json.loads(metadata["non_tensor_keys"])
+    else:
+        non_tensor_keys = []
+    flat_dict = {k: _cast_to_object(v) if k in non_tensor_keys else v for k, v in flat_dict.items()}
+    if "param_groups" in flat_dict:
+        # 3-level dict
+        state_dict["param_groups"] = flat_dict.pop("param_groups")
+        state_dict["state"] = {}
+        states = state_dict["state"]
+    else:
+        # 2-level dict, usually for optimizer state dict shard
+        states = state_dict
+
+    for k, v in flat_dict.items():
+        parts = k.split(seperator)
+        assert len(parts) == 3 and parts[0] == "state"
+        idx = int(parts[1])
+        key = parts[2]
+        if idx not in states:
+            states[idx] = {}
+        states[idx][key] = v
+    return state_dict


 def prepare(
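_flatten_optim_state_dict collapses the nested {"state": {index: {key: value}}} layout into a flat, safetensors-friendly dict keyed "state.<index>.<key>", pickling non-tensor values via _cast_to_tensor and recording their keys in the metadata so _unflatten_optim_state_dict knows which entries to cast back. A worked sketch of the expected mapping (values illustrative):

    import torch

    state_dict = {
        "state": {0: {"step": 5, "exp_avg": torch.zeros(2)}},
        "param_groups": [{"lr": 0.001, "params": [0]}],
    }
    flat, metadata = _flatten_optim_state_dict(state_dict)
    # Every value in flat is a tensor; non-tensors were pickled to byte tensors.
    assert set(flat) == {"state.0.step", "state.0.exp_avg", "param_groups"}
    assert metadata == {"non_tensor_keys": ["state.0.step", "param_groups"]}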
@@ -124,10 +162,8 @@ def save(
         f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)


-def save_nested(
-    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
-) -> None:
-    flatten_data = flatten_dict(state_dict)
+def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
+    flatten_data, metadata = _flatten_optim_state_dict(state_dict)
     save(f_writer, flatten_data, metadata)
@@ -154,10 +190,5 @@ def load_flat(checkpoint_path):
     with safe_open(checkpoint_path, framework="pt") as f:
         metadata = f.metadata()
     state_dict_load = load_file(checkpoint_path)
-    state_dict = unflatten_dict(state_dict_load)
-    if metadata is None:
-        return state_dict
-    metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items()))
-    combined_state_dict = {"state": state_dict}
-    combined_state_dict.update(metadata)
-    return combined_state_dict
+    state_dict = _unflatten_optim_state_dict(state_dict_load, metadata)
+    return state_dict
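With these changes, save_nested and load_flat form a symmetric pair, and the old flatten_dict / unflatten_dict round trip (plus the manual metadata juggling in load_flat) is gone. A minimal end-to-end sketch, reusing the AsyncFileWriter settings from this commit's test:

    import torch
    from tensornvme.async_file_io import AsyncFileWriter

    from colossalai.utils.safetensors import load_flat, save_nested

    state_dict = {
        "state": {0: {"step": torch.tensor(1.0), "exp_avg": torch.rand(8)}},
        "param_groups": [{"lr": 0.001, "params": [0]}],
    }
    f_writer = AsyncFileWriter(fp=open("optim.safetensors", "wb"), n_entries=191, backend="pthread")
    save_nested(f_writer, state_dict)
    f_writer.sync_before_step()
    f_writer.synchronize()
    f_writer.fp.close()
    loaded = load_flat("optim.safetensors")  # same nested layout as state_dict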

tests/test_checkpoint_io/test_safetensors_async_io.py (64 lines changed)

@@ -1,9 +1,9 @@
 import tempfile
-from copy import deepcopy

 import torch
+from safetensors.torch import load_file

-from colossalai.utils.safetensors import load_flat, save_nested
+from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested

 try:
     from tensornvme.async_file_io import AsyncFileWriter
@@ -11,17 +11,29 @@ except ModuleNotFoundError:
     raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")

 from colossalai.testing import check_state_dict_equal
+from colossalai.utils import get_current_device


 def test_save_load():
     with tempfile.TemporaryDirectory() as tempdir:
         optimizer_state_dict = {
-            0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-            1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-            2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-        }
-        # group_dict = {"param_groups": [0, 1, 2]}
-        group_dict = {
+            "state": {
+                0: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+                1: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+                2: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+            },
             "param_groups": [
                 {
                     "lr": 0.001,
@@ -94,22 +106,26 @@ def test_save_load():
                         61,
                     ],
                 }
-            ]
+            ],
         }
-        metadata = deepcopy(group_dict)

         optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
         f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
-        save_nested(f_writer, optimizer_state_dict, metadata)
+        save_nested(f_writer, optimizer_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         f_writer.fp.close()

         load_state_dict = load_flat(optimizer_saved_path)
-        state_dict = load_state_dict["state"]
-        group = {"param_groups": load_state_dict["param_groups"]}
-        check_state_dict_equal(optimizer_state_dict, state_dict)
-        check_state_dict_equal(group_dict, group)
+        check_state_dict_equal(load_state_dict, optimizer_state_dict)
+
+        optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
+        f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread")
+        save_nested(f_writer, optimizer_state_dict["state"])
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+        load_state_dict_shard = load_flat(optimizer_shard_saved_path)
+        check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])

         model_state_dict = {
             "module.weight0": torch.rand((1024, 1024)),
@@ -118,10 +134,20 @@ def test_save_load():
         }
         model_saved_path = f"{tempdir}/save_model.safetensors"
         f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
-        save_nested(f_writer, model_state_dict)
+        save(f_writer, model_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         f_writer.fp.close()

-        load_state_dict = load_flat(model_saved_path)
+        load_state_dict = load_file(model_saved_path)
+        check_state_dict_equal(model_state_dict, load_state_dict)
+        model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
+        model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
+        model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
+        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
+        move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+        load_state_dict = load_file(model_saved_path)
         check_state_dict_equal(model_state_dict, load_state_dict)
