mirror of https://github.com/hpcaitech/ColossalAI
[checkpointio] fix for async io (#6189)
parent
5ff5323538
commit
ce0ec40811
|
@@ -315,12 +315,13 @@ def async_save_state_dict_shards(
|
||||||
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
||||||
|
|
||||||
if state_preprocess:
|
if state_preprocess:
|
||||||
state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator=".")
|
state_dict, metadata = _flatten_optim_state_dict(state_dict=shard, seperator=".")
|
||||||
else:
|
else:
|
||||||
state_dict = shard
|
state_dict = shard
|
||||||
|
metadata = None
|
||||||
|
|
||||||
# Only save on master rank.
|
# Only save on master rank.
|
||||||
writer = save(checkpoint_file_path, state_dict=state_dict)
|
writer = save(checkpoint_file_path, state_dict=state_dict, metadata=metadata)
|
||||||
writers.append(writer)
|
writers.append(writer)
|
||||||
shard_filenames.append(shard_file)
|
shard_filenames.append(shard_file)
|
||||||
del shard
|
del shard
|
||||||
|
@@ -377,9 +378,10 @@ def async_move_save_state_dict_shards(
|
||||||
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
checkpoint_file_path = os.path.join(checkpoint, shard_file)
|
||||||
|
|
||||||
if state_preprocess:
|
if state_preprocess:
|
||||||
state_dict, _ = _flatten_optim_state_dict(state_dict=shard)
|
state_dict, metadata = _flatten_optim_state_dict(state_dict=shard)
|
||||||
else:
|
else:
|
||||||
state_dict = shard
|
state_dict = shard
|
||||||
|
metadata = None
|
||||||
|
|
||||||
if pinned_state_dict is not None:
|
if pinned_state_dict is not None:
|
||||||
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in state_dict.keys()}
|
sub_pinned_state_dict = {k: pinned_state_dict[k] for k in state_dict.keys()}
|
||||||
|
@@ -388,7 +390,7 @@ def async_move_save_state_dict_shards(
|
||||||
returned_state_dict.update(sub_pinned_state_dict)
|
returned_state_dict.update(sub_pinned_state_dict)
|
||||||
|
|
||||||
# Only save on master rank.
|
# Only save on master rank.
|
||||||
writer = move_and_save(checkpoint_file_path, state_dict, sub_pinned_state_dict)
|
writer = move_and_save(checkpoint_file_path, state_dict, sub_pinned_state_dict, metadata)
|
||||||
writers.append(writer)
|
writers.append(writer)
|
||||||
shard_filenames.append(shard_file)
|
shard_filenames.append(shard_file)
|
||||||
del shard
|
del shard
|
||||||
|
|
Loading…
Reference in New Issue