# coding=utf-8
import os
import re
from collections import abc as container_abcs
from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple

import torch
import torch.nn as nn
from packaging.version import Version
from torch.optim import Optimizer
from torch.utils._pytree import tree_map

from colossalai.tensor.d_tensor import (
    is_customized_distributed_tensor,
    is_distributed_tensor,
    to_global,
    to_global_for_customized_distributed_tensor,
)

SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
STATES_NAME = "pytorch_optim.bin"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
GROUP_FILE_NAME = "pytorch_optim_group.bin"

# ======================================
# General helper functions
# ======================================


def calculate_tensor_size(tensor: torch.Tensor) -> float:
    """
    Calculate the size of a parameter in MB. Used to compute whether a group of params exceeds the shard size.
    If so, a new shard should be created.

    Args:
        tensor (torch.Tensor): the tensor to calculate the size for.

    Returns:
        float: size of the tensor in MB.
    """
    return tensor.numel() * tensor.element_size() / 1024 / 1024
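
# Illustrative example (not part of the original module): a 1024 x 1024 float32
# tensor has 1024 * 1024 elements of 4 bytes each, i.e. exactly 4 MB:
#   >>> calculate_tensor_size(torch.zeros(1024, 1024, dtype=torch.float32))
#   4.0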


def is_safetensors_available() -> bool:
    """
    Check whether safetensors is available.

    Returns:
        bool: whether safetensors is available.
    """
    try:
        import safetensors  # noqa: F401

        return True
    except ImportError:
        return False


def is_dtensor_checkpoint(checkpoint_file_path: str) -> bool:
    """
    Check whether the checkpoint file is a dtensor checkpoint.

    Args:
        checkpoint_file_path (str): path to the checkpoint file.

    Returns:
        bool: whether the checkpoint file is a dtensor checkpoint.
    """
    if checkpoint_file_path.endswith(".*.safetensors") or checkpoint_file_path.endswith(".*.bin"):
        return True
    else:
        return False


def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
    """
    Check whether the checkpoint file is a safetensor checkpoint.

    Args:
        checkpoint_file_path (str): path to the checkpoint file.

    Returns:
        bool: whether the checkpoint file is a safetensor checkpoint.
    """
    if checkpoint_file_path.endswith(".safetensors"):
        return True
    else:
        return False


def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size, tp_size: int) -> Optional[int]:
    """
    Given the current shape of a parameter and its shape before sharding,
    return the dimension along which the parameter is sharded when using tensor parallelism.
    If tensor parallelism is not used, return None.

    Args:
        current_shape (torch.Size): The current shape of the parameter after sharding.
        original_shape (torch.Size): The shape of the parameter before sharding.
        tp_size (int): The size of the tp group.

    Returns:
        Optional[int]: The dimension along which the parameter is partitioned.
    """
    partition_dim = None
    for dim, length in enumerate(original_shape):
        if length > current_shape[dim]:
            partition_dim = dim
            break
    if partition_dim is not None:
        assert (
            original_shape[partition_dim] == tp_size * current_shape[partition_dim]
        ), f"The parameter isn't evenly distributed among tensor parallel group: \
shape before sharding {original_shape}, shape after sharding {current_shape}"

    return partition_dim
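
# Illustrative example (not part of the original module): with tp_size=4, a weight
# whose full shape is (1024, 1024) and whose local shard has shape (1024, 256) is
# split along dimension 1:
#   >>> search_tp_partition_dim(torch.Size([1024, 256]), torch.Size([1024, 1024]), tp_size=4)
#   1
#   >>> search_tp_partition_dim(torch.Size([1024, 1024]), torch.Size([1024, 1024]), tp_size=1)  # returns None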


def search_padding_dim(global_shape: torch.Size, original_shape: torch.Size) -> Optional[int]:
    """
    Return the first dimension along which ``global_shape`` is larger than ``original_shape``
    (i.e. the dimension that was padded), or None if no dimension was padded.
    """
    padding_dim = None
    for dim, length in enumerate(global_shape):
        if length > original_shape[dim]:
            padding_dim = dim
            break
    return padding_dim
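
# Illustrative example (not part of the original module): if an embedding's vocabulary
# dimension was padded from 50257 to 50304 rows, the padded dimension is 0:
#   >>> search_padding_dim(torch.Size([50304, 768]), torch.Size([50257, 768]))
#   0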


# ======================================
# Helper classes and functions for saving shard file
# ======================================


class StateDictSharder:
    def __init__(self, size_per_shard: int) -> None:
        self.max_shard_size = size_per_shard
        self.current_block = OrderedDict()
        self.current_block_size = 0

    def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
        tensor_size = calculate_tensor_size(tensor)
        ret_block = None
        ret_block_size = 0

        # Before we return the current block and create a new block,
        # we need to ensure that the current block is not empty.
        if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
            ret_block = self.current_block
            ret_block_size = self.current_block_size
            self.current_block = OrderedDict()
            self.current_block_size = 0

        self.current_block[name] = tensor
        self.current_block_size += tensor_size
        return ret_block, ret_block_size

    def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]:
        # A state might contain more than one tensor,
        # e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'.
        state_size = 0
        isDTensor = False
        for state_tensor in state.values():
            # When state_tensor is not of Tensor class,
            # e.g., an SGD optimizer with momentum set to 0 can have None as state,
            # the calculation of tensor size should be skipped to avoid error.
            if not isinstance(state_tensor, torch.Tensor):
                continue

            # If the states are stored as DTensors, mark isDTensor as true.
            if is_distributed_tensor(state_tensor):
                isDTensor = True
            state_size += calculate_tensor_size(state_tensor)

        ret_block = None
        ret_block_size = 0

        # Directly return if the state is stored as a distributed tensor.
        if isDTensor:
            return ret_block, ret_block_size

        # Before we return the current block and create a new block,
        # we need to ensure that the current block is not empty.
        if self.current_block_size + state_size > self.max_shard_size and self.current_block_size > 0:
            ret_block = self.current_block
            ret_block_size = self.current_block_size
            self.current_block = OrderedDict()
            self.current_block_size = 0

        self.current_block[param_id] = state
        self.current_block_size += state_size
        return ret_block, ret_block_size
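
# Illustrative usage (not part of the original module): shard a state dict into blocks
# of at most 10 MB. A non-None return value from append_param is a completed block that
# can be written out; the final partial block is read from sharder.current_block afterwards.
#   >>> sharder = StateDictSharder(size_per_shard=10)
#   >>> for name, tensor in model.state_dict().items():  # `model` is assumed to exist
#   ...     block, block_size = sharder.append_param(name, tensor)
#   ...     if block is not None:
#   ...         pass  # persist `block` (e.g. via save_state_dict) and record `block_size`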


def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> torch.Tensor:
    """
    Gather the complete parameter for saving if the passed-in param is distributed under a tp setting.

    Args:
        param (torch.Tensor): A model parameter, might be a d_tensor.
        keep_vars (bool, optional): Whether to return the parameter in the calculation graph. Defaults to False.

    Returns:
        torch.Tensor: the complete parameter.
    """
    param_ = param if keep_vars else param.detach()
    if is_distributed_tensor(param_):
        return to_global(param_)
    elif is_customized_distributed_tensor(param_):
        return to_global_for_customized_distributed_tensor(param_)
    else:
        return param_


def save_state_dict_shards(
    sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
    checkpoint: str,
    index_file: "CheckpointIndexFile",
    base_filename: str,
    is_master: bool,
    use_safetensors: bool = False,
    use_pp_format: bool = False,
) -> int:
    """
    Save a sharded state dict only on the master rank. This method can be used for both model and optimizer states.

    Args:
        sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains a state dict and its shard size.
        checkpoint (str): The path of the checkpoint directory as a string.
        index_file (CheckpointIndexFile): The index file object to be updated.
        base_filename (str): Decides the prefix of the filenames of shards.
        is_master (bool): Whether the current rank is the main process.
        use_safetensors (bool, optional): Whether to use safetensors to save the checkpoint. Defaults to False.
        use_pp_format (bool, optional): Whether to save the files in pipeline format, including stage information. Defaults to False.

    Returns:
        int: the total size of shards.
    """
    total_size = 0
    shard_filenames = []
    for idx, shard_pair in enumerate(sharded_state_dict):
        shard, current_size = shard_pair
        if not is_master:
            del shard
            continue
        shard_file = get_shard_filename(base_filename, idx)
        total_size = total_size + current_size
        for key in shard.keys():
            index_file.append_weight_map(key, shard_file)
        checkpoint_file_path = os.path.join(checkpoint, shard_file)

        # Only save on master rank.
        save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
        shard_filenames.append(shard_file)
        del shard

    # Clean folder, deleting unneeded files.
    clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)

    return total_size


def shard_model_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
    """
    Splits a model state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not
    exceed a given size.
    """
    state_dict_sharder = StateDictSharder(max_shard_size)

    for key, weight in state_dict.items():
        if not is_distributed_tensor(weight):
            block, block_size = state_dict_sharder.append_param(key, weight)
            if block is not None:
                yield block, block_size

    # Return the last block in the sharder.
    yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
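
# Illustrative usage (not part of the original module): stream shards of a model state
# dict, capping each shard at roughly 1024 MB; save_state_dict_shards consumes exactly
# this kind of generator when writing a sharded checkpoint.
#   >>> for block, block_size in shard_model_checkpoint(model.state_dict(), max_shard_size=1024):
#   ...     pass  # `model` is assumed to exist; write `block` out and record `block_size`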


def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
    """
    Splits an optimizer state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not
    exceed a given size.
    """
    # Only split state_dict['state']; state_dict['param_groups'] is not considered in this function.
    states = state_dict["state"]
    state_dict_sharder = StateDictSharder(max_shard_size)

    for param_id, state in states.items():
        block, block_size = state_dict_sharder.append_optim_state(param_id, state)
        if block is not None:
            yield block, block_size

    # Return the last block in the sharder.
    yield state_dict_sharder.current_block, state_dict_sharder.current_block_size


# ======================================
# Helper functions for saving state dict
# ======================================


def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
    """
    Save a state dict to checkpoint.

    Args:
        state_dict (dict): state dict.
        checkpoint_file_path (str): path to the checkpoint file.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
    """
    # Move all tensors in the state_dict to CPU before saving to avoid serialization issues.
    state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)

    if use_safetensors:
        assert is_safetensors_available(), "safetensors is not available."
        assert checkpoint_file_path.endswith(
            ".safetensors"
        ), "safetensors only supports the .safetensors suffix for checkpoint files."
        from safetensors.torch import save_file as safe_save_file

        safe_save_file(state_dict_cpu, checkpoint_file_path, metadata={"format": "pt"})
    else:
        torch.save(state_dict_cpu, checkpoint_file_path)


def save_param_groups(state_dict: dict, group_file_path: str) -> None:
    """
    Save the param_groups information of an optimizer state dict to the given file path.

    Args:
        state_dict (dict): state dict.
        group_file_path (str): path to the group file.
    """
    param_groups = state_dict["param_groups"]
    torch.save(param_groups, group_file_path)


def clean_folder(
    checkpoint_path: str,
    weights_name: str,
    shard_filenames: List[str],
    is_master: bool = True,
    use_pp_format: bool = False,
):
    """
    Clean the unneeded files in the checkpoint directory after shards of state_dict have been saved.

    Args:
        checkpoint_path (str): Path to the checkpoint directory.
        weights_name (str): Decides the prefix of filenames of weight shards.
        shard_filenames (List[str]): The list of saved shard filenames which should not be removed.
        is_master (bool, optional): Whether the current rank is the main process. Defaults to True.
        use_pp_format (bool, optional): Whether the files are saved in pipeline format, including stage information. Defaults to False.
    """
    if is_master:
        for filename in os.listdir(checkpoint_path):
            full_filename = os.path.join(checkpoint_path, filename)
            weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
            filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
            if not use_pp_format:
                reg = re.compile(r"(.*?)-\d{5}")
            else:
                # When this checkpoint is created by a pipeline parallel process, the pattern is a little different.
                reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
            if (
                filename.startswith(weights_no_suffix)
                and os.path.isfile(full_filename)
                and filename not in shard_filenames
                and reg.fullmatch(filename_no_suffix) is not None
            ):
                os.remove(full_filename)


def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = True):
    """
    Save config.json/generation_config.json if the model is a HuggingFace pretrained model.
    This method can only be called when a model is saved in a sharded way.

    Args:
        model (nn.Module): The model whose config should be saved if it's a HuggingFace model.
        checkpoint_path (str): Path to the checkpoint directory.
        is_master (bool): Whether the current rank is the main process.
    """
    try:
        from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype
        from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
    except ImportError:
        return
    if not isinstance(model, PreTrainedModel):
        return

    model = unwrap_huggingface_model(model)

    # Save the string version of dtype to the config, e.g. convert torch.float32 => "float32".
    dtype = get_parameter_dtype(model)
    model.config.torch_dtype = str(dtype).split(".")[1]

    # Attach the architecture to the config.
    model.config.architectures = [model.__class__.__name__]

    # Save the config.
    if is_master:
        model.config.save_pretrained(checkpoint_path)
        if model.can_generate():
            model.generation_config.save_pretrained(checkpoint_path)


def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
    """
    Save a distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
    only one tensor.

    Args:
        tensor (Tensor): tensor to be saved.
        index_file (CheckpointIndexFile): path to the checkpoint file.
        size_per_shard (int): size per shard in MB.
    """
    root_path = index_file.root_path
    output_root_path = root_path.joinpath("dtensor")

    # create directory
    output_root_path.mkdir(exist_ok=True)

    # save tensor to this directory
    # TODO(YuliangLiu): get index of the tensor shard
    # e.g. index =
    index = 0

    # save tensor to file
    ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors)
    ckpt_file_path = output_root_path.joinpath(ckpt_file_name)

    # dtensor ckpt file always contains only one tensor
    state_dict = {name: tensor}
    save_state_dict(state_dict, str(ckpt_file_path), use_safetensors)

    # update the weight map
    # * means all shards
    ckpt_file_name_in_weight_map = "dtensor/" + generate_dtensor_file_name(name, "*", use_safetensors)
    index_file.append_weight_map(name, ckpt_file_name_in_weight_map)


def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
    """
    Get the checkpoint file suffix.

    Args:
        use_safetensors (bool): whether to use safetensors to save the checkpoint.

    Returns:
        str: checkpoint file suffix.
    """
    if use_safetensors:
        return ".safetensors"
    else:
        return ".bin"


def generate_checkpoint_shard_file_name(
    index: int, total_number: int, use_safetensors: bool, prefix: str = None
) -> str:
    """
    Generate a checkpoint shard file name.

    Args:
        index (int): index of the shard.
        total_number (int): total number of shards.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.
        prefix (str): prefix of the shard file name. Default: None.

    Returns:
        str: checkpoint shard file name.
    """
    # The suffix already contains the leading dot (".bin" or ".safetensors").
    suffix = get_checkpoint_file_suffix(use_safetensors)

    if prefix is None:
        return f"{index:05d}-of-{total_number:05d}{suffix}"
    else:
        return f"{prefix}-{index:05d}-of-{total_number:05d}{suffix}"


def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str:
    """
    Generate a dtensor file name.

    Args:
        param_name (str): name of the distributed parameter.
        index (int): index of the shard.
        use_safetensors (bool): whether to use safetensors to save the checkpoint.

    Returns:
        str: dtensor file name.
    """
    # The suffix already contains the leading dot, so the generated name has the form
    # "<param_name>.<index>.bin" / "<param_name>.<index>.safetensors", which is the
    # pattern that is_dtensor_checkpoint expects.
    suffix = get_checkpoint_file_suffix(use_safetensors)
    return f"{param_name}.{index}{suffix}"


# ========================================
# Helper functions for loading state dict
# ========================================


def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
    """
    Load a shard state dict into the model.
    """
    if use_safetensors and not checkpoint_file.suffix == ".safetensors":
        raise Exception("load the model using `safetensors`, but no file ends with .safetensors")
    if use_safetensors:
        from safetensors.torch import load_file as safe_load_file
        from safetensors.torch import safe_open

        with safe_open(checkpoint_file, framework="pt") as f:
            metadata = f.metadata()
        if metadata["format"] != "pt":
            raise NotImplementedError(
                f"Conversion from a {metadata['format']} safetensors archive to PyTorch is not implemented yet."
            )
        return safe_load_file(checkpoint_file)
    else:
        return torch.load(checkpoint_file, map_location=torch.device("cpu"))


def load_state_dict_into_model(
    model: nn.Module, state_dict: dict, missing_keys: List, strict: bool = False, load_sub_module: bool = True
):
    r"""Copies parameters and buffers from :attr:`state_dict` into
    this module and its descendants.

    Args:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
    """
    if not isinstance(state_dict, Mapping):
        raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))

    unexpected_keys: List[str] = []
    sub_missing_keys: List[str] = []
    error_msgs: List[str] = []

    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
        # Parameters of module and children will start with prefix. We can exit early if there are none in this
        # state_dict.
        if len([key for key in state_dict if key.startswith(prefix)]) > 0:
            module._load_from_state_dict(*args)

        if load_sub_module:
            for name, child in module._modules.items():
                if child is not None:
                    load(child, state_dict, prefix + name + ".")

    load(model, state_dict, "", load_sub_module)
    del load

    missing_keys.append(sub_missing_keys)

    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys))
            )
            raise RuntimeError(
                "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
            )


def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str) -> dict:
    """
    Load information of param_groups into an initialized optimizer.
    """
    # Load the list of param_groups from the given file path.
    # The params in saved_groups are in the form of integer indices.
    saved_groups = torch.load(param_group_path, map_location=torch.device("cpu"))
    if not isinstance(saved_groups, List):
        raise ValueError(f"The param_groups saved at {param_group_path} is not of List type")

    # The params in param_groups are in the form of pytorch tensors.
    # For more details, please view the source code of the Optimizer class in pytorch.
    param_groups = optimizer.param_groups

    # Check the compatibility of saved_groups and param_groups.
    if len(param_groups) != len(saved_groups):
        raise ValueError("loaded state dict has a different number of original parameter groups")
    param_lens = (len(g["params"]) for g in param_groups)
    saved_lens = (len(g["params"]) for g in saved_groups)
    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
        raise ValueError(
            "loaded state dict contains a parameter group that doesn't match the size of optimizer's group"
        )

    # Create a mapping from id to parameters.
    id_map = {
        old_id: p
        for old_id, p in zip(
            chain.from_iterable((g["params"] for g in saved_groups)),
            chain.from_iterable((g["params"] for g in param_groups)),
        )
    }

    # Update parameter groups, setting their 'params' value.
    def update_group(group, new_group):
        new_group["params"] = group["params"]
        return new_group

    updated_groups = [update_group(g, ng) for g, ng in zip(param_groups, saved_groups)]

    optimizer.__dict__.update({"param_groups": updated_groups})
    return id_map


def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict, strict: bool = False):
    r"""Copies states from `state_dict` into an Optimizer object.

    Args:
        optimizer (Optimizer): An initialized Optimizer object to be loaded.
        state_dict (dict): A mapping from tensor index (an integer)
            to its states to be loaded (a mapping from state name to a tensor).
        id_map (dict): A mapping from tensor index (an integer)
            to its corresponding parameter (a tensor) whose states will be updated.
        strict (bool, optional): If set to True, only load the parameters whose id is in id_map. Defaults to False.
    """
    # Ensure that the keys of state_dict are integers.
    state_dict = {int(k): v for k, v in state_dict.items()}

    def cast(param, value, key=None):
        r"""Make a deep copy of value, casting all tensors to the device of param."""
        if isinstance(value, torch.Tensor):
            # Floating-point types are a bit special here. They are the only ones
            # that are assumed to always match the type of params.
            # Make sure state['step'] is not casted: https://github.com/pytorch/pytorch/issues/74424
            if key != "step":
                if param.is_floating_point():
                    value = value.to(param.dtype)
                value = value.to(param.device)
            return value
        elif isinstance(value, dict):
            return {k: cast(param, v, key=k) for k, v in value.items()}
        elif isinstance(value, container_abcs.Iterable):
            return type(value)(cast(param, v) for v in value)
        else:
            return value

    # Copy state assigned to params (and cast tensors to appropriate types).
    # State that is not assigned to params is copied as is (needed for
    # backward compatibility).
    new_states = defaultdict(dict)
    for k, v in state_dict.items():
        if k in id_map:
            param = id_map[k]
            new_states[param] = cast(param, v)
        elif not strict:
            new_states[k] = v

    optimizer.state.update(new_states)


def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
    r"""Do the cleaning-up work after state_dict has been loaded into the optimizer.

    Args:
        optimizer (Optimizer): An optimizer object whose state has just been loaded.
    """
    # Do the cleaning up as in the source code of PyTorch.
    if Version(torch.__version__) >= Version("2.0.0"):
        optimizer._patch_step_function()  # To support multiprocessing pickle/unpickle
    else:
        optimizer._hook_for_profile()  # To support multiprocessing pickle/unpickle
    optimizer.defaults.setdefault("differentiable", False)
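
# Illustrative loading workflow (not part of the original module), assuming an
# already-initialized `optimizer` and a checkpoint saved with the filenames defined
# above (GROUP_FILE_NAME and STATES_NAME-based shard names):
#   >>> id_map = load_param_groups_into_optimizer(optimizer, "pytorch_optim_group.bin")
#   >>> states = load_shard_state_dict(Path("pytorch_optim-00001.bin"))
#   >>> load_states_into_optimizer(optimizer, states, id_map)
#   >>> sharded_optimizer_loading_epilogue(optimizer)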


def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
    """
    Check whether the checkpoint has an index file.

    Args:
        checkpoint_path (str): path to the checkpoint.

    Returns:
        Tuple[bool, Optional[Path]]: a tuple of (has_index_file, index_file_path)
    """
    checkpoint_path = Path(checkpoint_path)
    if checkpoint_path.is_file():
        # check if it is .index.json
        reg = re.compile(r"(.*?).index((\..*)?).json")
        if reg.fullmatch(checkpoint_path.name) is not None:
            return True, checkpoint_path
        else:
            return False, None
    elif checkpoint_path.is_dir():
        # check if there is only one file ending with .index.json in this directory
        index_files = list(checkpoint_path.glob("*.index.*json"))

        # if we found a .index.json file, make sure there is only one
        if len(index_files) > 0:
            assert (
                len(index_files) == 1
            ), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}"

        if len(index_files) == 1:
            return True, index_files[0]
        else:
            return False, None
    else:
        raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.")


def load_state_dict(checkpoint_file_path: Path):
    """
    Load state dict from checkpoint.

    Args:
        checkpoint_file_path (Path): path to the checkpoint file.

    Returns:
        dict: state dict.
    """

    assert not is_dtensor_checkpoint(
        checkpoint_file_path
    ), f"Cannot load state dict from dtensor checkpoint {checkpoint_file_path}, you should convert the distributed tensors to gathered tensors with our CLI offline."

    if is_safetensor_checkpoint(checkpoint_file_path):
        assert (
            is_safetensors_available()
        ), f"Cannot load state dict from safetensor checkpoint {checkpoint_file_path}, because safetensors is not available. Please install safetensors first with pip install safetensors."
        # load with safetensors
        from safetensors import safe_open

        state_dict = {}
        with safe_open(checkpoint_file_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                state_dict[k] = f.get_tensor(k)
        return state_dict

    else:
        # load with torch
        return torch.load(checkpoint_file_path, map_location=torch.device("cpu"))



def add_prefix(weights_name: str, prefix: Optional[str] = None) -> str:
    if prefix is not None and len(prefix) > 0:
        splits = weights_name.split(".")
        splits = splits[:-1] + [prefix] + splits[-1:]
        weights_name = ".".join(splits)
    return weights_name
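
# Illustrative example (not part of the original module): the prefix is inserted just
# before the file extension:
#   >>> add_prefix("pytorch_model.bin", "rank0")
#   'pytorch_model.rank0.bin'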


def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
    """
    Generate base model weight filenames.
    """
    weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
    weights_name = add_prefix(weights_name, prefix)

    save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
    save_index_file = add_prefix(save_index_file, prefix)

    return weights_name, save_index_file


def get_optimizer_base_filenames(prefix: str = None):
    """
    Generate base optimizer state filenames.
    """
    states_name = STATES_NAME
    states_name = add_prefix(states_name, prefix)

    save_index_file = STATES_INDEX_NAME
    save_index_file = add_prefix(save_index_file, prefix)

    param_group_file = GROUP_FILE_NAME
    param_group_file = add_prefix(param_group_file, prefix)

    return states_name, save_index_file, param_group_file


def get_shard_filename(weights_name: str, idx: int):
    """
    Get the shard file name.
    """
    shard_file = weights_name.replace(".bin", f"-{idx + 1:05d}.bin")
    shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
    return shard_file
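
# Illustrative example (not part of the original module): shard indices are zero-based,
# but the generated filenames start counting at 00001:
#   >>> get_shard_filename("pytorch_model.bin", 0)
#   'pytorch_model-00001.bin'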