[optim] hotfix adam load (#6146)

* [optim] hotfix adam load

* [checkpointio] fix optimizer async io

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [checkpointio] update test

* [checkpointio] update test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Hongxin Liu 2024-11-20 16:36:37 +08:00 committed by GitHub
parent 5caad13055
commit cf519dac6a
5 changed files with 139 additions and 76 deletions


@@ -142,7 +142,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             from colossalai.utils.safetensors import save_nested

             f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
-            save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
+            save_nested(f_writer, state_dict)
             self.async_writers.append(f_writer)
         else:
             save_state_dict(state_dict, checkpoint, use_safetensors=False)
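A minimal sketch of the new calling convention (illustrative only; assumes tensornvme is installed, and the path and optimizer are placeholders): save_nested now receives the whole optimizer state dict and splits out "param_groups" itself, instead of the caller passing the two parts separately.

import torch
from tensornvme.async_file_io import AsyncFileWriter
from colossalai.utils.safetensors import save_nested

opt = torch.optim.Adam(torch.nn.Linear(4, 4).parameters(), lr=1e-3)
f_writer = AsyncFileWriter(fp=open("optim.safetensors", "wb"), n_entries=191, backend="pthread")
save_nested(f_writer, opt.state_dict())  # one call; no separate {"param_groups": ...} argument
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()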


@@ -81,6 +81,14 @@ class CPUAdam(NVMeOptimizer):
             # if you find yourself stuck here, make sure that you install colossalai with BUILD_EXT=1 specification
             self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)

+    def load_state_dict(self, state_dict):
+        super().load_state_dict(state_dict)
+        for group in self.param_groups:
+            for p in group["params"]:
+                state = self.state[p]
+                if "step" in state and isinstance(state["step"], torch.Tensor):
+                    state["step"] = int(state["step"].item())
+
     def torch_adam_update(
         self,
         data,
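This hook matters because a checkpoint can carry "step" as a 0-dim tensor (the flat safetensors path later in this diff stores it that way, and recent torch optimizers often do too), while the fused C++ kernel expects a plain Python int. A sketch of the round trip being fixed, assuming colossalai was built with BUILD_EXT=1 so the CPUAdam extension is available:

import torch
from colossalai.nn.optimizer import CPUAdam

model = torch.nn.Linear(4, 4)
opt = CPUAdam(model.parameters(), lr=1e-3)
model(torch.rand(2, 4)).sum().backward()
opt.step()

ckpt = opt.state_dict()
# Simulate a checkpoint that serialized "step" as a 0-dim tensor.
ckpt["state"][0]["step"] = torch.tensor(float(ckpt["state"][0]["step"]))
opt.load_state_dict(ckpt)
assert isinstance(opt.state[model.weight]["step"], int)  # cast back by the new hook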


@@ -1,4 +1,4 @@
-from typing import Any, List, OrderedDict, Tuple
+from typing import Any, List, OrderedDict

 import torch
 import torch.distributed as dist
@@ -78,9 +78,7 @@ def check_state_dict_equal(
                     v1 = v1.to(v2.dtype)
                 assert_close_loose(v1, v2)
             else:
-                if isinstance(v1, Tuple) and not isinstance(v2, Tuple):
-                    v2 = tuple(v2)
-                assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}"
+                assert v1 == v2, f"{v1} not equals to {v2}"


 def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
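The Tuple special-case can be dropped because the new pickle-based serialization in colossalai.utils.safetensors preserves exact Python types, whereas the old JSON-metadata path turned tuples such as betas into lists. A standalone illustration (not part of the commit):

import json
import pickle

group = {"lr": 0.001, "betas": (0.9, 0.999)}
assert json.loads(json.dumps(group))["betas"] == [0.9, 0.999]  # JSON loses the tuple
restored = pickle.loads(pickle.dumps(group))
assert restored == group and isinstance(restored["betas"], tuple)  # pickle keeps it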


@@ -1,6 +1,5 @@
 # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
 import json
-import warnings
 from dataclasses import asdict, dataclass
 from typing import Dict, List, Optional, Tuple

@@ -12,6 +11,26 @@ try:
 except ModuleNotFoundError:
     raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")

 _TYPES_INV = {v: k for k, v in _TYPES.items()}

+import io
+
+from torch.distributed.distributed_c10d import _pickler, _unpickler
+
+
+def _object_to_tensor(obj, device):
+    f = io.BytesIO()
+    _pickler(f).dump(obj)
+    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())  # type: ignore[attr-defined]
+    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
+    # Otherwise, it will cause 100X slowdown.
+    # See: https://github.com/pytorch/pytorch/issues/65696
+    byte_tensor = torch.ByteTensor(byte_storage).to(device)
+    return byte_tensor
+
+
+def _tensor_to_object(tensor, tensor_size):
+    tensor = tensor.cpu()
+    buf = tensor.numpy().tobytes()[:tensor_size]
+    return _unpickler(io.BytesIO(buf)).load()
+

 @dataclass
@@ -28,49 +47,68 @@ class PreparedData:
     offset: int


-def flatten_dict(nested_dict, parent_key="", separator="^"):
-    """
-    Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator.
-
-    nested_dict: The input nested dictionary.
-    parent_key: The parent key currently being processed.
-    separator: The separator used to join keys, default is '_', but can be customized to another symbol.
-
-    :return: A flattened dictionary.
-    """
-    items = []
-    for k, v in nested_dict.items():
-        new_key = f"{parent_key}{separator}{k}" if parent_key else str(k)
-        if isinstance(v, dict):
-            items.extend(flatten_dict(v, new_key, separator).items())
-        else:
-            v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v
-            items.append((new_key, v))
-
-    return dict(items)
-
-
-def unflatten_dict(flattened_dict, separator="^"):
-    """
-    Restore a flattened dictionary back to a multi-level nested dictionary.
-
-    flattened_dict: The flattened dictionary.
-    separator: The separator used during flattening, default is '_', but can be customized to another symbol.
-
-    :return: The restored nested dictionary.
-    """
-    nested_dict = {}
-    for key, value in flattened_dict.items():
-        keys = key.split(separator)
-        try:
-            keys[0] = int(keys[0])
-        except ValueError:
-            warnings.warn(f"{key[0]} can't convert to integer")
-        d = nested_dict
-        for part in keys[:-1]:
-            if part not in d:
-                d[part] = {}
-            d = d[part]
-        assert isinstance(value, torch.Tensor)
-        d[keys[-1]] = value
-
-    return nested_dict
+def _cast_to_tensor(obj):
+    if isinstance(obj, torch.Tensor):
+        return obj
+    return _object_to_tensor(obj, "cpu")
+
+
+def _cast_to_object(tensor: torch.Tensor):
+    return _tensor_to_object(tensor, tensor.numel() * tensor.element_size())
+
+
+def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[dict, Optional[dict]]:
+    flat_dict = {}
+    non_tensor_keys = []
+    if "state" in state_dict:
+        # 3-level dict
+        states = state_dict["state"]
+    else:
+        # 2-level dict, usually for optimizer state dict shard
+        states = state_dict
+
+    for idx, d in states.items():
+        for k, v in d.items():
+            nested_key = f"state{seperator}{idx}{seperator}{k}"
+            if not isinstance(v, torch.Tensor):
+                non_tensor_keys.append(nested_key)
+            flat_dict[nested_key] = _cast_to_tensor(v)
+    if "param_groups" in state_dict:
+        flat_dict["param_groups"] = _cast_to_tensor(state_dict["param_groups"])
+        non_tensor_keys.append("param_groups")
+    if len(non_tensor_keys) > 0:
+        metadata = {"non_tensor_keys": json.dumps(non_tensor_keys)}
+    else:
+        metadata = None
+    return flat_dict, metadata
+
+
+def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None, seperator: str = "."):
+    state_dict = {}
+    if metadata is not None:
+        non_tensor_keys = json.loads(metadata["non_tensor_keys"])
+    else:
+        non_tensor_keys = []
+    flat_dict = {k: _cast_to_object(v) if k in non_tensor_keys else v for k, v in flat_dict.items()}
+    if "param_groups" in flat_dict:
+        # 3-level dict
+        state_dict["param_groups"] = flat_dict.pop("param_groups")
+        state_dict["state"] = {}
+        states = state_dict["state"]
+    else:
+        # 2-level dict, usually for optimizer state dict shard
+        states = state_dict
+
+    for k, v in flat_dict.items():
+        parts = k.split(seperator)
+        assert len(parts) == 3 and parts[0] == "state"
+        idx = int(parts[1])
+        key = parts[2]
+        if idx not in states:
+            states[idx] = {}
+        states[idx][key] = v
+    return state_dict


 def prepare(
@@ -124,10 +162,8 @@ def save(
     f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)


-def save_nested(
-    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
-) -> None:
-    flatten_data = flatten_dict(state_dict)
+def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
+    flatten_data, metadata = _flatten_optim_state_dict(state_dict)
     save(f_writer, flatten_data, metadata)

@@ -154,10 +190,5 @@ def load_flat(checkpoint_path):
     with safe_open(checkpoint_path, framework="pt") as f:
         metadata = f.metadata()
     state_dict_load = load_file(checkpoint_path)
-    state_dict = unflatten_dict(state_dict_load)
-    if metadata is None:
-        return state_dict
-    metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items()))
-    combined_state_dict = {"state": state_dict}
-    combined_state_dict.update(metadata)
-    return combined_state_dict
+    state_dict = _unflatten_optim_state_dict(state_dict_load, metadata)
+    return state_dict
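A round-trip sketch of the helpers above, assuming the private functions are importable from colossalai.utils.safetensors: nested optimizer state flattens to "state.<idx>.<key>" tensor entries, non-tensor values are pickled into byte tensors, and their keys are recorded in the metadata so loading can restore them.

import torch
from colossalai.utils.safetensors import _flatten_optim_state_dict, _unflatten_optim_state_dict

nested = {
    "state": {0: {"step": 10, "exp_avg": torch.rand(4)}},
    "param_groups": [{"lr": 0.001, "betas": (0.9, 0.999)}],
}
flat, metadata = _flatten_optim_state_dict(nested)
assert set(flat) == {"state.0.step", "state.0.exp_avg", "param_groups"}
assert all(isinstance(v, torch.Tensor) for v in flat.values())  # everything is a tensor now

restored = _unflatten_optim_state_dict(flat, metadata)  # load_flat passes f.metadata() here
assert restored["state"][0]["step"] == 10
assert restored["param_groups"][0]["betas"] == (0.9, 0.999)  # tuple survives the pickle round trip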


@@ -1,9 +1,9 @@
 import tempfile
-from copy import deepcopy

 import torch
+from safetensors.torch import load_file

-from colossalai.utils.safetensors import load_flat, save_nested
+from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested

 try:
     from tensornvme.async_file_io import AsyncFileWriter
@@ -11,17 +11,29 @@ except ModuleNotFoundError:
     raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")

 from colossalai.testing import check_state_dict_equal
+from colossalai.utils import get_current_device


 def test_save_load():
     with tempfile.TemporaryDirectory() as tempdir:
         optimizer_state_dict = {
-            0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-            1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-            2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-        }
-        # group_dict = {"param_groups": [0, 1, 2]}
-        group_dict = {
+            "state": {
+                0: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+                1: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+                2: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+            },
             "param_groups": [
                 {
                     "lr": 0.001,
@@ -94,22 +106,26 @@ def test_save_load():
                     61,
                 ],
             }
-        ]
-        }
-        metadata = deepcopy(group_dict)
+            ],
+        }

         optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
         f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
-        save_nested(f_writer, optimizer_state_dict, metadata)
+        save_nested(f_writer, optimizer_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         f_writer.fp.close()

         load_state_dict = load_flat(optimizer_saved_path)
-        state_dict = load_state_dict["state"]
-        group = {"param_groups": load_state_dict["param_groups"]}
-        check_state_dict_equal(optimizer_state_dict, state_dict)
-        check_state_dict_equal(group_dict, group)
+        check_state_dict_equal(load_state_dict, optimizer_state_dict)
+
+        optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
+        f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread")
+        save_nested(f_writer, optimizer_state_dict["state"])
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+        load_state_dict_shard = load_flat(optimizer_shard_saved_path)
+        check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])

         model_state_dict = {
             "module.weight0": torch.rand((1024, 1024)),
@@ -118,10 +134,20 @@ def test_save_load():
         }
         model_saved_path = f"{tempdir}/save_model.safetensors"
         f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
-        save_nested(f_writer, model_state_dict)
+        save(f_writer, model_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         f_writer.fp.close()
-        load_state_dict = load_flat(model_saved_path)
+        load_state_dict = load_file(model_saved_path)
+        check_state_dict_equal(model_state_dict, load_state_dict)
+
+        model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
+        model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
+        model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
+        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
+        move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+        load_state_dict = load_file(model_saved_path)
         check_state_dict_equal(model_state_dict, load_state_dict)