2023-07-21 06:39:01 +00:00
import gc
2023-05-05 11:36:10 +00:00
import logging
import os
2023-12-08 03:10:51 +00:00
import random
2023-05-05 06:37:21 +00:00
from pathlib import Path
2023-10-31 07:19:37 +00:00
from typing import Callable , Dict , Iterator , List , Optional , Tuple
2023-03-31 08:06:13 +00:00
2023-12-08 03:10:51 +00:00
import numpy as np
2023-03-31 08:06:13 +00:00
import torch
2023-11-10 02:15:16 +00:00
import torch . distributed as dist
2023-03-31 08:06:13 +00:00
import torch . nn as nn
2023-11-20 11:46:43 +00:00
from torch . distributed . distributed_c10d import _get_default_group
2023-03-31 08:06:13 +00:00
from torch . optim import Optimizer
from torch . optim . lr_scheduler import _LRScheduler as LRScheduler
from torch . utils . data import DataLoader
2023-12-08 03:10:51 +00:00
from torch . utils . data . distributed import DistributedSampler
2023-03-31 08:06:13 +00:00
2024-01-09 06:57:07 +00:00
from colossalai . accelerator import get_accelerator
2023-05-05 11:36:10 +00:00
from colossalai . checkpoint_io import CheckpointIndexFile , CheckpointIO , GeneralCheckpointIO
2023-07-21 06:39:01 +00:00
from colossalai . checkpoint_io . utils import (
get_model_base_filenames ,
get_optimizer_base_filenames ,
load_shard_state_dict ,
2023-09-04 15:25:01 +00:00
save_config_file ,
2023-07-21 06:39:01 +00:00
save_state_dict ,
save_state_dict_shards ,
)
2023-11-10 02:15:16 +00:00
from colossalai . cluster import DistCoordinator , ProcessGroupMesh
2023-03-31 08:06:13 +00:00
from colossalai . interface import ModelWrapper , OptimizerWrapper
2023-11-10 02:15:16 +00:00
from colossalai . shardformer import ShardConfig , ShardFormer
2023-08-24 01:29:25 +00:00
from colossalai . zero import GeminiDDP , GeminiOptimizer
2023-04-04 05:48:16 +00:00
from colossalai . zero . gemini . memory_tracer import MemStats
2023-03-31 08:06:13 +00:00
2023-05-05 11:36:10 +00:00
from . dp_plugin_base import DPPluginBase
2023-03-31 08:06:13 +00:00
2023-09-19 06:20:26 +00:00
__all__ = [ " GeminiPlugin " ]
2023-03-31 08:06:13 +00:00
2023-09-19 06:20:26 +00:00
SUPPORTED_PRECISION = [ " fp16 " , " bf16 " ]
PRECISION_STR_TO_DTYPE = { " fp16 " : torch . half , " bf16 " : torch . bfloat16 }
2023-06-05 07:58:31 +00:00
2023-11-16 13:03:04 +00:00
ZERO_AXIS , DP_AXIS , TP_AXIS = 0 , 1 , 2
2023-03-31 08:06:13 +00:00
2023-11-20 08:12:41 +00:00
2023-11-10 02:15:16 +00:00
def get_param_info ( optim : Optimizer ) :
# Get a backup of necessary information of parameters for future use, which includes:
# 1. A mapping from integer param_id to param32 shape.
if optim is None :
return { }
param_info = { " id2shape " : { } }
[shardformer] refactor embedding resize (#5603)
* [branch rebase] rebase main to Feature/resize_embedding (#5554)
* fix
* [release] update version (#5411)
* [hotfix] fix typo s/keywrods/keywords etc. (#5429)
* [devops] fix compatibility (#5444)
* [devops] fix compatibility
* [hotfix] update compatibility test on pr
* [devops] fix compatibility
* [devops] record duration during comp test
* [test] decrease test duration
* fix falcon
* [shardformer] fix gathering output when using tensor parallelism (#5431)
* fix
* padding vocab_size when using pipeline parallellism
padding vocab_size when using pipeline parallellism
fix
fix
* fix
* fix
fix
fix
* fix gather output
* fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* revert
* [doc] release Open-Sora 1.0 with model weights (#5468)
* [doc] release Open-Sora 1.0 with model weights
* [doc] release Open-Sora 1.0 with model weights
* [doc] release Open-Sora 1.0 with model weights
* [doc] update open-sora demo (#5479)
* [doc] update open-sora demo
* [doc] update open-sora demo
* [doc] update open-sora demo
* [example] add grok-1 inference (#5485)
* [misc] add submodule
* remove submodule
* [example] support grok-1 tp inference
* [example] add grok-1 inference script
* [example] refactor code
* [example] add grok-1 readme
* [exmaple] add test ci
* [exmaple] update readme
---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* [CI] run pre-commit (#5577)
* fix
* [release] update version (#5411)
* [hotfix] fix typo s/keywrods/keywords etc. (#5429)
* [devops] fix compatibility (#5444)
* [devops] fix compatibility
* [hotfix] update compatibility test on pr
* [devops] fix compatibility
* [devops] record duration during comp test
* [test] decrease test duration
* fix falcon
* [shardformer] fix gathering output when using tensor parallelism (#5431)
* fix
* padding vocab_size when using pipeline parallellism
padding vocab_size when using pipeline parallellism
fix
fix
* fix
* fix
fix
fix
* fix gather output
* fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* revert
* [doc] release Open-Sora 1.0 with model weights (#5468)
* [doc] release Open-Sora 1.0 with model weights
* [doc] release Open-Sora 1.0 with model weights
* [doc] release Open-Sora 1.0 with model weights
* [doc] update open-sora demo (#5479)
* [doc] update open-sora demo
* [doc] update open-sora demo
* [doc] update open-sora demo
* [example] add grok-1 inference (#5485)
* [misc] add submodule
* remove submodule
* [example] support grok-1 tp inference
* [example] add grok-1 inference script
* [example] refactor code
* [example] add grok-1 readme
* [exmaple] add test ci
* [exmaple] update readme
* run pre-commit
---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* [rebase] rebase main to resize-embedding (#5581)
* [release] grok-1 314b inference (#5490)
* [release] grok-1 inference
* [release] grok-1 inference
* [release] grok-1 inference
* [example] update Grok-1 inference (#5495)
* revise grok-1 example
* remove unused arg in scripts
* prevent re-installing torch
* update readme
* revert modifying colossalai requirements
* add perf
* trivial
* add tokenizer url
* [hotfix] set return_outputs=False in examples and polish code (#5404)
* fix: simplify merge_batch
* fix: use return_outputs=False to eliminate extra memory consumption
* feat: add return_outputs warning
* style: remove `return_outputs=False` as it is the default value
* [release] grok-1 inference benchmark (#5500)
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [shardformer]Fix lm parallel. (#5480)
* fix
* padding vocab_size when using pipeline parallellism
padding vocab_size when using pipeline parallellism
fix
fix
* fix
* fix
fix
fix
* fix gather output
* fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* revert
* fix lm forward distribution
* fix
* test ci
* fix
* [fix] fix grok-1 example typo (#5506)
* [devops] fix example test ci (#5504)
* Fix ColoTensorSpec for py11 (#5440)
* fixed layout converter caching and updated tester
* Empty-Commit
* [shardformer] update colo attention to support custom mask (#5510)
* [feature] refactor colo attention (#5462)
* [extension] update api
* [feature] add colo attention
* [feature] update sdpa
* [feature] update npu attention
* [feature] update flash-attn
* [test] add flash attn test
* [test] update flash attn test
* [shardformer] update modeling to fit colo attention (#5465)
* [misc] refactor folder structure
* [shardformer] update llama flash-attn
* [shardformer] fix llama policy
* [devops] update tensornvme install
* [test] update llama test
* [shardformer] update colo attn kernel dispatch
* [shardformer] update blip2
* [shardformer] update chatglm
* [shardformer] update gpt2
* [shardformer] update gptj
* [shardformer] update opt
* [shardformer] update vit
* [shardformer] update colo attention mask prep
* [shardformer] update whisper
* [test] fix shardformer tests (#5514)
* [test] fix shardformer tests
* [test] fix shardformer tests
* [format] applied code formatting on changed files in pull request 5510 (#5517)
Co-authored-by: github-actions <github-actions@github.com>
* [shardformer] fix pipeline forward error if custom layer distribution is used (#5189)
* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution
* Change static methods for t5 layer distribution to member functions
* Change static methods for whisper layer distribution to member functions
* Replace whisper policy usage with self one
* Fix test case to use non-static layer distribution methods
* fix: fix typo
---------
Co-authored-by: Wenhao Chen <cwher@outlook.com>
* [Fix] Grok-1 use tokenizer from the same pretrained path (#5532)
* [fix] use tokenizer from the same pretrained path
* trust remote code
* [ColossalChat] Update RLHF V2 (#5286)
* Add dpo. Fix sft, ppo, lora. Refactor all
* fix and tested ppo
* 2 nd round refactor
* add ci tests
* fix ci
* fix ci
* fix readme, style
* fix readme style
* fix style, fix benchmark
* reproduce benchmark result, remove useless files
* rename to ColossalChat
* use new image
* fix ci workflow
* fix ci
* use local model/tokenizer for ci tests
* fix ci
* fix ci
* fix ci
* fix ci timeout
* fix rm progress bar. fix ci timeout
* fix ci
* fix ci typo
* remove 3d plugin from ci temporary
* test environment
* cannot save optimizer
* support chat template
* fix readme
* fix path
* test ci locally
* restore build_or_pr
* fix ci data path
* fix benchmark
* fix ci, move ci tests to 3080, disable fast tokenizer
* move ci to 85
* support flash attention 2
* add all-in-one data preparation script. Fix colossal-llama2-chat chat template
* add hardware requirements
* move ci test data
* fix save_model, add unwrap
* fix missing bos
* fix missing bos; support grad accumulation with gemini
* fix ci
* fix ci
* fix ci
* fix llama2 chat template config
* debug sft
* debug sft
* fix colossalai version requirement
* fix ci
* add sanity check to prevent NaN loss
* fix requirements
* add dummy data generation script
* add dummy data generation script
* add dummy data generation script
* add dummy data generation script
* update readme
* update readme
* update readme and ignore
* fix logger bug
* support parallel_output
* modify data preparation logic
* fix tokenization
* update lr
* fix inference
* run pre-commit
---------
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
* [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogenous shard policy for llama (#5508)
* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`
* feat: apply `GradientCheckpointConfig` to policy and llama_forward
* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager
* fix: add optional args for `distribute_layer` and `get_stage_index`
* fix: fix changed API calls
* test: update llama tests
* style: polish `GradientCheckpointConfig`
* fix: fix pipeline utils tests
* fix incorrect sharding without zero (#5545)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [shardformer] Sequence Parallelism Optimization (#5533)
* sequence parallel optimization
* validate sequence parallel in llama (code to be polished)
* shardformer api writing
* integrate sequence parallel in ShardFormer
* fix pp bugs and sp bugs for LlaMa model
* integrating ring-based sequence parallelism into ShardFormer
* [sequence parallelism]: Add fused megatron function
* integrating ring-based sequence parallelism into ShardFormer
---------
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
* fix bugs when useing sp and flashattention together
* fix operation function name
* support flash attention for ulysses-style sp
* clarify sp process group
* fix compatibility bugs in moe plugin
* fix fused linear bugs
* fix linear layer test
* support gpt model all-to-all sp
* modify shard data dimension (meant to be dim=-1)
* support megtron-style sp and distributed attn for llama model
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatability
* finish sp mode 3 support for gpt
* using all_to_all_single when batch size is 1
* support mode 2 sp in gpt2 (#5)
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatability
* refactor ring implementation
* support mode 2 sp in gpt2
* polish code
* enable distributed attn mask when using sp mode 2 and 3 in llama
* automatically enable flash attn when using sp mode 2 and 3 in llama
* inplace attn mask
* add zero2 support for sequence parallel
* polish code
* fix bugs
* fix gemini checkpoint io
* loose tensor checking atol and rtol
* add comment
* fix llama layernorm grad
* fix zero grad
* fix zero grad
* fix conflict
* update split and gather auto grad func
* sequence parallel: inside text split (#6)
* polish code (part 1)
* polish code (part 2)
* polish code (part 2.5)
* polish code (part 3)
* sequence parallel: inside text split
* miscellaneous minor fixes
* polish code
* fix ulysses style ZeRO
* sequence parallel: inside text split
* miscellaneous minor fixes
* disaggregate sp group and dp group for sp
* fix llama and gpt sp
* polish code
* move ulysses grad sync to ddp (#9)
* remove zero_stage and unbind the grad sync for alltoall sp
* add 2d group creation test
* move ulysses grad sync to ddp
* add 2d group creation test
* remove useless code
* change shard config not to enable sp when enable_all_optimizations
* add sp warnings for several model
* remove useless code
---------
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
* [hotfix] quick fixes to make legacy tutorials runnable (#5559)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [fix] fix typo s/muiti-node /multi-node etc. (#5448)
* [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548)
* [devops] remove post commit ci (#5566)
* [devops] remove post commit ci
* [misc] run pre-commit on all files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---------
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [shardformer]enable padding vocabulary size. (#5489)
* padding vocab_size when using pipeline parallellism
padding vocab_size when using pipeline parallellism
fix
fix
* fix
* fix
fix
fix
* fix gather output
* fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* revert
* padding vocab
* padding vocabe
* fix
* fix
* fxi
* test ci
* fix
fix
fix
fix
* fix
fix
* fix
* fix
* Update hybrid_parallel_plugin.py
fix
fix
fix
* fix
fix
* fix
fix
* fix
* resolve super init
resolve super init
resolve super init
resolve super init
* resolve comments
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* vocab checkpointio
* padding vocab_size when using pipeline parallellism
padding vocab_size when using pipeline parallellism
fix
fix
* fix
fix
fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* padding vocab
* fix
* fix
fix
* fix
fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* cherry-pick
* revert moe modify
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
fix
fix
fix
fix
fix
fix
fix
* resolve comments
resolve comments
resolve comments
resolve comments
resolve comments
* ptensor
ptensor
resolve comments
fix
fix
fix
fix
fix
resolve comments
resolve comments
resolve comments
resolve comments
resolve comments
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix rebase
* fix rebase
---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-04-18 08:10:18 +00:00
2023-11-10 02:15:16 +00:00
start_index = 0
for group in optim . param_groups :
for param_id , param in enumerate ( group [ " params " ] , start_index ) :
original_shape = param . shape if isinstance ( param , torch . Tensor ) else None
param_info [ " id2shape " ] [ param_id ] = original_shape
start_index + = len ( group [ " params " ] )
return param_info
2023-11-20 08:12:41 +00:00
2023-03-31 08:06:13 +00:00
class GeminiCheckpointIO ( GeneralCheckpointIO ) :
def __init__ ( self ) - > None :
super ( ) . __init__ ( )
self . coordinator = DistCoordinator ( )
2023-04-06 01:43:51 +00:00
def save_unsharded_model ( self , model : GeminiDDP , checkpoint : str , gather_dtensor : bool , use_safetensors : bool ) :
2023-03-31 08:06:13 +00:00
"""
2023-07-07 08:33:06 +00:00
Save sharded model to checkpoint but only on master process .
The model should be unwrapped in self . load_model via ModelWrapper . unwrap .
2023-07-21 06:39:01 +00:00
As there is communication when getting state dict , model . state_dict ( ) must be called on all processes .
2023-03-31 08:06:13 +00:00
"""
2023-09-20 10:29:37 +00:00
assert isinstance ( model , GeminiDDP ) , " Please boost the model before saving! "
2023-03-31 08:06:13 +00:00
state_dict = model . state_dict ( only_rank_0 = True )
if self . coordinator . is_master ( ) :
2023-04-06 01:43:51 +00:00
save_state_dict ( state_dict , checkpoint , use_safetensors )
2023-03-31 08:06:13 +00:00
2023-07-07 08:33:06 +00:00
def load_unsharded_model ( self , model : GeminiDDP , checkpoint : str , strict : bool = True ) :
2023-03-31 08:06:13 +00:00
"""
2023-07-07 08:33:06 +00:00
Load model from checkpoint with automatic unwrapping .
The model should be unwrapped in self . load_model via ModelWrapper . unwrap .
2023-03-31 08:06:13 +00:00
"""
2023-09-20 10:29:37 +00:00
assert isinstance ( model , GeminiDDP ) , " Please boost the model before loading! "
2023-07-07 08:33:06 +00:00
super ( ) . load_unsharded_model ( model , checkpoint , strict = strict )
2023-05-19 11:42:31 +00:00
2023-09-20 10:29:37 +00:00
def save_unsharded_optimizer ( self , optimizer : GeminiOptimizer , checkpoint : str , gather_dtensor : bool ) :
2023-03-31 08:06:13 +00:00
"""
2023-07-07 08:33:06 +00:00
Save unsharded optimizer state dict to checkpoint .
After calling optimizer . state_dict ( ) , the complete optimizer states will be collected on master rank .
2023-07-21 06:39:01 +00:00
As there is communication when getting state dict , optimizer . state_dict ( ) must be called on all processes .
2023-07-07 08:33:06 +00:00
The saving process will only be executed by master rank .
2023-03-31 08:06:13 +00:00
"""
2023-09-20 10:29:37 +00:00
assert isinstance ( optimizer , GeminiOptimizer ) , " Please boost the optimizer before saving! "
2023-07-07 08:33:06 +00:00
state_dict = optimizer . state_dict ( )
2023-03-31 08:06:13 +00:00
if self . coordinator . is_master ( ) :
2023-07-07 08:33:06 +00:00
save_state_dict ( state_dict , checkpoint , use_safetensors = False )
2023-09-20 10:29:37 +00:00
def load_unsharded_optimizer ( self , optimizer : GeminiOptimizer , checkpoint : str ) :
2023-07-07 08:33:06 +00:00
"""
Loading unsharded optimizer from checkpoint file .
For each process , only loading optimizer states of parameters it controls .
"""
2023-09-20 10:29:37 +00:00
assert isinstance ( optimizer , GeminiOptimizer ) , " Please boost the optimizer before loading! "
2023-07-07 08:33:06 +00:00
super ( ) . load_unsharded_optimizer ( optimizer , checkpoint )
2023-03-31 08:06:13 +00:00
2023-09-19 06:20:26 +00:00
def save_sharded_model (
self ,
model : GeminiDDP ,
checkpoint_path : str ,
gather_dtensor : bool = False ,
prefix : Optional [ str ] = None ,
max_shard_size : int = 1024 ,
use_safetensors : bool = False ,
) :
2023-05-05 06:37:21 +00:00
"""
2023-07-21 06:39:01 +00:00
Save sharded model .
As there is communication when getting state dict , model . state_dict ( ) must be called on all processes .
2023-05-05 06:37:21 +00:00
"""
2023-09-20 10:29:37 +00:00
assert isinstance ( model , GeminiDDP ) , " Please boost the model before saving! "
2023-07-07 08:33:06 +00:00
if os . path . isfile ( checkpoint_path ) :
logging . error ( f " Provided path ( { checkpoint_path } ) should be a directory, not a file " )
return
Path ( checkpoint_path ) . mkdir ( parents = True , exist_ok = True )
2023-10-12 02:39:08 +00:00
state_dict_shard = model . state_dict_shard ( max_shard_size = max_shard_size , only_rank_0 = True )
2023-06-15 07:21:26 +00:00
weights_name , save_index_file = get_model_base_filenames ( prefix , use_safetensors )
2023-05-05 06:37:21 +00:00
index_file = CheckpointIndexFile ( checkpoint_path )
2023-05-05 11:36:10 +00:00
2023-07-21 06:39:01 +00:00
# Save shards of optimizer states.
is_master = self . coordinator . is_master ( )
2023-09-19 06:20:26 +00:00
total_size = save_state_dict_shards (
sharded_state_dict = state_dict_shard ,
checkpoint = checkpoint_path ,
index_file = index_file ,
base_filename = weights_name ,
is_master = is_master ,
use_safetensors = use_safetensors ,
)
2023-06-09 01:48:49 +00:00
# only save the index file on the master rank
if self . coordinator . is_master ( ) :
2023-07-21 06:39:01 +00:00
index_file . append_meta_data ( " total_size " , total_size )
2023-06-09 01:48:49 +00:00
index_file . write_index_file ( save_index_file )
2023-09-20 10:29:37 +00:00
save_config_file ( model . unwrap ( ) , checkpoint_path )
2023-09-19 06:20:26 +00:00
logging . info (
f " The model is split into checkpoint shards. "
f " You can find where each parameters has been saved in the "
f " index located at { save_index_file } . "
)
def load_sharded_model (
self , model : GeminiDDP , checkpoint_index_file : Path , strict : bool = False , use_safetensors : bool = False
) :
2023-05-05 06:37:21 +00:00
"""
2023-07-21 06:39:01 +00:00
Load shard model , load model from multiple files .
2023-05-05 06:37:21 +00:00
"""
2023-09-20 10:29:37 +00:00
assert isinstance ( model , GeminiDDP ) , " Please boost the model before loading! "
2023-05-05 06:37:21 +00:00
return super ( ) . load_sharded_model ( model , checkpoint_index_file , strict , use_safetensors , load_sub_module = False )
2023-03-31 08:06:13 +00:00
2023-09-19 06:20:26 +00:00
def save_sharded_optimizer (
2023-09-20 10:29:37 +00:00
self , optimizer : GeminiOptimizer , checkpoint : Path , gather_dtensor : bool , prefix : str , size_per_shard : int
2023-09-19 06:20:26 +00:00
) :
2023-07-07 08:33:06 +00:00
"""
Save sharded optimizer state dict to checkpoint folder .
As there is communication when getting state dict , this must be called on all processes .
"""
2023-09-20 10:29:37 +00:00
assert isinstance ( optimizer , GeminiOptimizer ) , " Please boost the optimizer before saving! "
2023-07-21 06:39:01 +00:00
if os . path . isfile ( checkpoint ) :
logging . error ( f " Provided path ( { checkpoint } ) should be a directory, not a file " )
return
2023-07-07 08:33:06 +00:00
Path ( checkpoint ) . mkdir ( parents = True , exist_ok = True )
2023-07-21 06:39:01 +00:00
# Preparing file paths and index file.
states_name , save_index_file , param_group_file = get_optimizer_base_filenames ( prefix )
index_file = CheckpointIndexFile ( checkpoint )
2023-10-31 06:48:01 +00:00
index_file . append_meta_data ( " param_groups " , param_group_file )
2023-07-21 06:39:01 +00:00
# Store the information of param groups to param_group_file.
2023-10-31 06:48:01 +00:00
if self . coordinator . is_master ( ) :
group_file_path = os . path . join ( checkpoint , param_group_file )
param_groups = optimizer . get_param_groups_for_saving ( )
torch . save ( param_groups , group_file_path )
2023-07-21 06:39:01 +00:00
# States are broken into shards within max_shard_size.
state_dict_shard = optimizer . state_shard ( prefix = prefix , max_shard_size = size_per_shard , only_rank_0 = True )
# Save shards of optimizer states.
2023-09-19 06:20:26 +00:00
total_size = save_state_dict_shards (
sharded_state_dict = state_dict_shard ,
checkpoint = checkpoint ,
index_file = index_file ,
base_filename = states_name ,
2023-10-31 06:48:01 +00:00
is_master = self . coordinator . is_master ( ) ,
2023-09-19 06:20:26 +00:00
use_safetensors = False ,
)
2023-07-21 06:39:01 +00:00
# Wrap up index file. Only save it on master rank.
if self . coordinator . is_master ( ) :
index_file . append_meta_data ( " total_size " , total_size )
index_file . write_index_file ( save_index_file )
2023-09-19 06:20:26 +00:00
logging . info (
f " The optimizer is going to be split to checkpoint shards. "
f " You can find where each parameters has been saved in the "
f " index located at { save_index_file } . "
)
2023-07-07 08:33:06 +00:00
2023-09-20 10:29:37 +00:00
def load_sharded_optimizer ( self , optimizer : GeminiOptimizer , checkpoint_index_file : Path , prefix : str ) :
2023-07-07 08:33:06 +00:00
"""
Loading sharded optimizer from checkpoint folder , with index file given .
For each process , only loading optimizer states of parameters it controls .
"""
2023-09-20 10:29:37 +00:00
assert isinstance ( optimizer , GeminiOptimizer ) , " Please boost the optimizer before loading! "
2023-07-21 06:39:01 +00:00
if not os . path . isfile ( checkpoint_index_file ) :
logging . error ( f " Provided path ( { checkpoint_index_file } ) should be a file " )
2023-08-24 01:29:25 +00:00
assert isinstance ( optimizer , GeminiOptimizer )
2023-07-21 06:39:01 +00:00
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile . from_file ( checkpoint_index_file )
# Load param_groups.
param_group_path = ckpt_index_file . get_param_group_filename ( )
if param_group_path is None :
2023-09-19 06:20:26 +00:00
raise RuntimeError (
f " Invalid index file path { checkpoint_index_file } for an optimizer. \
Lacking param group file under current directory . "
)
2023-07-21 06:39:01 +00:00
saved_param_groups = torch . load ( param_group_path )
optimizer . load_param_groups ( saved_param_groups )
checkpoint_files , _ = ckpt_index_file . get_checkpoint_filenames ( )
# Load optimizer states from shard files under checkpoint path.
# For each file, only load the states managed by current process.
for shard_file in checkpoint_files :
state_dict_shard = load_shard_state_dict ( Path ( shard_file ) , use_safetensors = False )
optimizer . load_param_states ( state_dict_shard )
del state_dict_shard
gc . collect ( )
optimizer . optimizer_loading_epilogue ( )
def save_lr_scheduler ( self , lr_scheduler : LRScheduler , checkpoint : str ) :
"""
Save model to checkpoint but only on master process .
"""
if self . coordinator . is_master ( ) :
super ( ) . save_lr_scheduler ( lr_scheduler , checkpoint )
2023-07-07 08:33:06 +00:00
2023-05-05 11:36:10 +00:00
class GeminiPlugin ( DPPluginBase ) :
2023-03-31 08:06:13 +00:00
"""
Plugin for Gemini .
2023-09-26 02:57:47 +00:00
` ` ` python
from colossalai . booster import Booster
from colossalai . booster . plugin import GeminiPlugin
model , train_dataset , optimizer , criterion = . . .
plugin = GeminiPlugin ( )
train_dataloader = plugin . prepare_dataloader ( train_dataset , batch_size = 8 )
booster = Booster ( plugin = plugin )
model , optimizer , train_dataloader , criterion = booster . boost ( model , optimizer , train_dataloader , criterion )
` ` `
2023-03-31 08:06:13 +00:00
Args :
2023-08-24 01:29:25 +00:00
chunk_config_dict ( dict , optional ) : chunk configuration dictionary .
chunk_init_device ( torch . device , optional ) : device to initialize the chunk .
placement_policy ( str , optional ) : " static " and " auto " . Defaults to " static " .
2023-10-17 06:07:21 +00:00
enable_gradient_accumulation ( bool , optional ) : Whether to enable gradient accumulation . When set to True , gradient will be stored after doing backward pass . Defaults to False .
2023-08-24 01:29:25 +00:00
shard_param_frac ( float , optional ) : fraction of parameters to be sharded . Only for " static " placement .
If ` shard_param_frac ` is 1.0 , it ' s equal to zero-3. If `shard_param_frac` is 0.0, it ' s equal to zero - 2. Defaults to 1.0 .
offload_optim_frac ( float , optional ) : fraction of optimizer states to be offloaded . Only for " static " placement .
If ` shard_param_frac ` is 1.0 and ` offload_optim_frac ` is 0.0 , it ' s equal to old " cuda " placement. Defaults to 0.0.
offload_param_frac ( float , optional ) : fraction of parameters to be offloaded . Only for " static " placement .
For efficiency , this argument is useful only when ` shard_param_frac ` is 1.0 and ` offload_optim_frac ` is 1.0 .
If ` shard_param_frac ` is 1.0 , ` offload_optim_frac ` is 1.0 and ` offload_param_frac ` is 1.0 , it ' s equal to old " cpu " placement.
When using static placement , we recommend users to tune ` shard_param_frac ` first and then ` offload_optim_frac ` .
Defaults to 0.0 .
warmup_non_model_data_ratio ( float , optional ) : ratio of expected non - model data memory during warmup . Only for " auto " placement . Defaults to 0.8 .
steady_cuda_cap_ratio ( float , optional ) : ratio of allowed cuda capacity for model data during steady state . Only for " auto " placement . Defaults to 0.9 .
2023-06-05 07:58:31 +00:00
precision ( str , optional ) : precision . Support ' fp16 ' and ' bf16 ' . Defaults to ' fp16 ' .
2023-10-17 06:07:21 +00:00
master_weights ( bool , optional ) : Whether to keep fp32 master parameter weights in optimizer . Defaults to True .
2023-03-31 08:06:13 +00:00
pin_memory ( bool , optional ) : use pin memory on CPU . Defaults to False .
force_outputs_fp32 ( bool , optional ) : force outputs are fp32 . Defaults to False .
strict_ddp_mode ( bool , optional ) : use strict ddp mode ( only use dp without other parallelism ) . Defaults to False .
2023-06-25 05:34:15 +00:00
search_range_m ( int , optional ) : chunk size searching range divided by 2 ^ 20. Defaults to 32.
2023-03-31 08:06:13 +00:00
hidden_dim ( int , optional ) : the hidden dimension of DNN .
Users can provide this argument to speed up searching .
If users do not know this argument before training , it is ok . We will use a default value 1024.
2023-06-25 05:34:15 +00:00
min_chunk_size_m ( float , optional ) : the minimum chunk size divided by 2 ^ 20.
2023-05-24 01:01:50 +00:00
If the aggregate size of parameters is still smaller than the minimum chunk size ,
2023-03-31 08:06:13 +00:00
all parameters will be compacted into one small chunk .
memstats ( MemStats , optional ) the memory statistics collector by a runtime memory tracer .
gpu_margin_mem_ratio ( float , optional ) : The ratio of GPU remaining memory ( after the first forward - backward )
which will be used when using hybrid CPU optimizer .
This argument is meaningless when ` placement_policy ` of ` GeminiManager ` is not " auto " .
Defaults to 0.0 .
2023-07-07 08:33:06 +00:00
initial_scale ( float , optional ) : Initial scale used by DynamicGradScaler . Defaults to 2 * * 16.
2023-03-31 08:06:13 +00:00
min_scale ( float , optional ) : Min scale used by DynamicGradScaler . Defaults to 1.
growth_factor ( float , optional ) : growth_factor used by DynamicGradScaler . Defaults to 2.
backoff_factor ( float , optional ) : backoff_factor used by DynamicGradScaler . Defaults to 0.5 .
growth_interval ( float , optional ) : growth_interval used by DynamicGradScaler . Defaults to 1000.
hysteresis ( float , optional ) : hysteresis used by DynamicGradScaler . Defaults to 2.
max_scale ( int , optional ) : max_scale used by DynamicGradScaler . Defaults to 2 * * 32.
max_norm ( float , optional ) : max_norm used for ` clip_grad_norm ` . You should notice that you shall not do
clip_grad_norm by yourself when using ZeRO DDP . The ZeRO optimizer will take care of clip_grad_norm .
norm_type ( float , optional ) : norm_type used for ` clip_grad_norm ` .
2023-11-16 13:03:04 +00:00
tp_size ( int , optional ) : If ' tp_size ' is set to be greater than 1 , it means using tensor parallelism strategy , which is implemented in Shardformer , ' tp_size ' determines the size of the tensor parallel process group . Default to 1.
extra_dp_size ( int , optional ) : If ' extra_dp_size ' is set to be greater than 1 , it means creating another group to run with a ddp - like strategy . Default to 1.
2023-11-10 02:15:16 +00:00
enable_all_optimization ( bool , optional ) : Whether to switch on all the optimizations supported by Shardformer .
Currently all the optimization methods include fused normalization , flash attention and JIT .
Defaults to False .
enable_fused_normalization ( bool , optional ) : Whether to switch on fused normalization in Shardformer . Defaults to False .
enable_flash_attention ( bool , optional ) : Whether to switch on flash attention in Shardformer . Defaults to False .
enable_jit_fused ( bool , optional ) : Whether to switch on JIT in Shardformer . Default to False .
enable_sequence_parallelism ( bool ) : Whether to turn on sequence parallelism in Shardformer . Defaults to False .
enable_sequence_overlap ( bool ) : Whether to turn on sequence overlap in Shardformer . Defaults to False .
2023-04-17 03:25:35 +00:00
verbose ( bool , optional ) : verbose mode . Debug info including chunk search result will be printed . Defaults to False .
2023-03-31 08:06:13 +00:00
"""
def __init__ (
self ,
2023-08-24 01:29:25 +00:00
chunk_config_dict : Optional [ dict ] = None ,
chunk_init_device : Optional [ torch . device ] = None ,
placement_policy : str = " static " ,
2023-10-17 06:07:21 +00:00
enable_gradient_accumulation : bool = False ,
2024-05-16 05:20:00 +00:00
max_prefetch : int = 0 ,
2023-09-19 06:20:26 +00:00
shard_param_frac : float = 1.0 , # only for static placement
offload_optim_frac : float = 0.0 , # only for static placement
offload_param_frac : float = 0.0 , # only for static placement
warmup_non_model_data_ratio : float = 0.8 , # only for auto placement
steady_cuda_cap_ratio : float = 0.9 , # only for auto placement
2023-06-05 07:58:31 +00:00
precision : str = " fp16 " ,
2023-10-12 02:39:08 +00:00
master_weights : bool = True ,
2023-03-31 08:06:13 +00:00
pin_memory : bool = False ,
force_outputs_fp32 : bool = False ,
strict_ddp_mode : bool = False ,
2023-06-25 05:34:15 +00:00
search_range_m : int = 32 ,
2023-03-31 08:06:13 +00:00
hidden_dim : Optional [ int ] = None ,
2023-06-25 05:34:15 +00:00
min_chunk_size_m : float = 32 ,
2023-03-31 08:06:13 +00:00
memstats : Optional [ MemStats ] = None ,
gpu_margin_mem_ratio : float = 0.0 ,
2023-07-07 08:33:06 +00:00
initial_scale : float = 2 * * 16 ,
2023-03-31 08:06:13 +00:00
min_scale : float = 1 ,
growth_factor : float = 2 ,
backoff_factor : float = 0.5 ,
growth_interval : int = 1000 ,
hysteresis : int = 2 ,
max_scale : float = 2 * * 32 ,
max_norm : float = 0.0 ,
norm_type : float = 2.0 ,
2023-11-10 02:15:16 +00:00
tp_size : int = 1 ,
2023-11-20 11:46:43 +00:00
extra_dp_size : int = 1 ,
2023-11-10 02:15:16 +00:00
enable_all_optimization : bool = False ,
enable_fused_normalization : bool = False ,
enable_flash_attention : bool = False ,
enable_sequence_parallelism : bool = False ,
enable_jit_fused : bool = False ,
enable_sequence_overlap : bool = False ,
2024-05-24 02:31:16 +00:00
enable_async_reduce : bool = True ,
2023-04-17 03:25:35 +00:00
verbose : bool = False ,
2023-03-31 08:06:13 +00:00
) - > None :
2023-05-05 11:36:10 +00:00
super ( ) . __init__ ( )
2023-09-19 06:20:26 +00:00
assert precision in SUPPORTED_PRECISION , f " precision { precision } is not supported "
2024-01-09 06:57:07 +00:00
if get_accelerator ( ) . name == " npu " :
2023-11-20 08:12:41 +00:00
assert placement_policy == " static " , " NPU only supports static placement policy "
2024-06-26 07:52:09 +00:00
if enable_async_reduce and not pin_memory :
2024-06-05 06:23:13 +00:00
logging . warning (
2024-06-26 07:52:09 +00:00
f " enable_async_reduce sets pin_memory=True to achieve best performance, which is not implicitly set. "
2024-06-05 06:23:13 +00:00
)
pin_memory = True
2023-03-31 08:06:13 +00:00
self . gemini_config = dict (
2023-08-24 01:29:25 +00:00
chunk_config_dict = chunk_config_dict ,
2024-01-09 02:20:05 +00:00
chunk_init_device = ( chunk_init_device or get_accelerator ( ) . get_current_device ( ) ) ,
2023-03-31 08:06:13 +00:00
placement_policy = placement_policy ,
2023-10-17 06:07:21 +00:00
enable_gradient_accumulation = enable_gradient_accumulation ,
2023-08-24 01:29:25 +00:00
shard_param_frac = shard_param_frac ,
offload_optim_frac = offload_optim_frac ,
offload_param_frac = offload_param_frac ,
warmup_non_model_data_ratio = warmup_non_model_data_ratio ,
steady_cuda_cap_ratio = steady_cuda_cap_ratio ,
2023-03-31 08:06:13 +00:00
pin_memory = pin_memory ,
force_outputs_fp32 = force_outputs_fp32 ,
strict_ddp_mode = strict_ddp_mode ,
2023-06-25 05:34:15 +00:00
search_range_m = search_range_m ,
2023-03-31 08:06:13 +00:00
hidden_dim = hidden_dim ,
2023-06-25 05:34:15 +00:00
min_chunk_size_m = min_chunk_size_m ,
2023-03-31 08:06:13 +00:00
memstats = memstats ,
2023-06-05 07:58:31 +00:00
mixed_precision = PRECISION_STR_TO_DTYPE [ precision ] ,
2023-10-12 02:39:08 +00:00
master_weights = master_weights ,
2024-05-16 04:45:06 +00:00
max_prefetch = max_prefetch ,
2024-05-24 02:31:16 +00:00
enable_async_reduce = enable_async_reduce ,
2023-03-31 08:06:13 +00:00
)
2023-09-19 06:20:26 +00:00
self . zero_optim_config = dict (
gpu_margin_mem_ratio = gpu_margin_mem_ratio ,
)
self . optim_kwargs = dict (
initial_scale = initial_scale ,
growth_factor = growth_factor ,
backoff_factor = backoff_factor ,
growth_interval = growth_interval ,
hysteresis = hysteresis ,
min_scale = min_scale ,
max_scale = max_scale ,
max_norm = max_norm ,
norm_type = norm_type ,
)
2023-11-16 13:03:04 +00:00
self . enable_tensor_parallelism = tp_size > 1
2023-11-10 02:15:16 +00:00
self . enable_all_optimization = enable_all_optimization
self . enable_fused_normalization = enable_fused_normalization
self . enable_flash_attention = enable_flash_attention
self . enable_sequence_parallelism = enable_sequence_parallelism if self . enable_tensor_parallelism else False
self . enable_jit_fused = enable_jit_fused
self . enable_sequence_overlap = enable_sequence_overlap
2023-04-17 03:25:35 +00:00
self . verbose = verbose
2023-03-31 08:06:13 +00:00
2023-11-16 13:03:04 +00:00
self . tp_size = tp_size
self . extra_dp_size = extra_dp_size
world_size = dist . get_world_size ( )
self . zero_size = world_size / / ( self . tp_size * self . extra_dp_size )
2023-11-20 11:46:43 +00:00
assert (
world_size == ( self . tp_size * self . extra_dp_size ) * self . zero_size
) , f " The global group size can ' t be evenly divided by the subgroup size. "
2023-11-16 13:03:04 +00:00
self . pg_mesh = ProcessGroupMesh ( self . zero_size , self . extra_dp_size , self . tp_size )
2023-11-20 11:46:43 +00:00
self . zero_group = (
self . pg_mesh . get_group_along_axis ( ZERO_AXIS ) if self . zero_size < world_size else _get_default_group ( )
)
2023-11-16 13:03:04 +00:00
self . extra_dp_group = self . pg_mesh . get_group_along_axis ( DP_AXIS ) if self . extra_dp_size > 1 else None
self . tp_group = self . pg_mesh . get_group_along_axis ( TP_AXIS ) if self . tp_size > 1 else None
2024-04-23 06:12:20 +00:00
self . dp_size = self . zero_size * self . extra_dp_size
2023-11-16 13:03:04 +00:00
2023-11-10 02:15:16 +00:00
self . shard_config = ShardConfig (
tensor_parallel_process_group = self . tp_group ,
enable_tensor_parallelism = self . enable_tensor_parallelism ,
enable_all_optimization = self . enable_all_optimization ,
enable_fused_normalization = self . enable_fused_normalization ,
enable_flash_attention = self . enable_flash_attention ,
enable_jit_fused = self . enable_jit_fused ,
enable_sequence_parallelism = self . enable_sequence_parallelism ,
enable_sequence_overlap = self . enable_sequence_overlap ,
)
2024-01-03 06:26:13 +00:00
def __del__ ( self ) :
2024-03-05 13:52:30 +00:00
""" Destroy the process groups in ProcessGroupMesh """
2024-01-03 06:26:13 +00:00
self . pg_mesh . destroy_mesh_process_groups ( )
2023-03-31 08:06:13 +00:00
def support_no_sync ( self ) - > bool :
return False
2023-10-31 07:19:37 +00:00
def support_lora ( self ) - > bool :
return False
2023-03-31 08:06:13 +00:00
def control_precision ( self ) - > bool :
return True
def supported_precisions ( self ) - > List [ str ] :
2023-06-05 07:58:31 +00:00
return SUPPORTED_PRECISION
2023-03-31 08:06:13 +00:00
def control_device ( self ) - > bool :
return True
def supported_devices ( self ) - > List [ str ] :
2023-11-20 08:12:41 +00:00
return [ " cuda " , " npu " ]
2023-03-31 08:06:13 +00:00
2023-12-08 03:10:51 +00:00
def prepare_dataloader (
2024-02-05 07:14:56 +00:00
self ,
dataset ,
batch_size ,
shuffle = False ,
seed = 1024 ,
drop_last = False ,
pin_memory = False ,
num_workers = 0 ,
distributed_sampler_cls = None ,
* * kwargs ,
2023-12-08 03:10:51 +00:00
) :
r """
Prepare a dataloader for distributed training . The dataloader will be wrapped by
` torch . utils . data . DataLoader ` and ` torch . utils . data . DistributedSampler ` .
Args :
dataset ( ` torch . utils . data . Dataset ` ) : The dataset to be loaded .
shuffle ( bool , optional ) : Whether to shuffle the dataset . Defaults to False .
seed ( int , optional ) : Random worker seed for sampling , defaults to 1024.
add_sampler : Whether to add ` ` DistributedDataParallelSampler ` ` to the dataset . Defaults to True .
drop_last ( bool , optional ) : Set to True to drop the last incomplete batch , if the dataset size
is not divisible by the batch size . If False and the size of dataset is not divisible by
the batch size , then the last batch will be smaller , defaults to False .
pin_memory ( bool , optional ) : Whether to pin memory address in CPU memory . Defaults to False .
num_workers ( int , optional ) : Number of worker threads for this dataloader . Defaults to 0.
kwargs ( dict ) : optional parameters for ` ` torch . utils . data . DataLoader ` ` , more details could be found in
` DataLoader < https : / / pytorch . org / docs / stable / _modules / torch / utils / data / dataloader . html #DataLoader>`_.
Returns :
: class : ` torch . utils . data . DataLoader ` : A DataLoader used for training or testing .
"""
_kwargs = kwargs . copy ( )
zero_world_size = self . pg_mesh . size ( ZERO_AXIS )
extra_dp_world_size = self . pg_mesh . size ( DP_AXIS )
zero_rank = self . pg_mesh . coordinate ( ZERO_AXIS )
extra_dp_rank = self . pg_mesh . coordinate ( DP_AXIS )
2024-02-05 07:14:56 +00:00
distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
sampler = distributed_sampler_cls (
2024-01-18 04:05:21 +00:00
dataset ,
num_replicas = zero_world_size * extra_dp_world_size ,
rank = zero_rank * extra_dp_world_size + extra_dp_rank ,
shuffle = shuffle ,
2023-12-08 03:10:51 +00:00
)
# Deterministic dataloader
def seed_worker ( worker_id ) :
worker_seed = seed
np . random . seed ( worker_seed )
torch . manual_seed ( worker_seed )
random . seed ( worker_seed )
return DataLoader (
dataset ,
batch_size = batch_size ,
sampler = sampler ,
worker_init_fn = seed_worker ,
drop_last = drop_last ,
pin_memory = pin_memory ,
num_workers = num_workers ,
* * _kwargs ,
)
2023-03-31 08:06:13 +00:00
def configure (
self ,
model : nn . Module ,
2023-06-15 09:38:42 +00:00
optimizer : Optional [ Optimizer ] = None ,
criterion : Optional [ Callable ] = None ,
dataloader : Optional [ DataLoader ] = None ,
lr_scheduler : Optional [ LRScheduler ] = None ,
) - > Tuple [ nn . Module , OptimizerWrapper , Callable , DataLoader , LRScheduler ] :
[shardformer] refactor embedding resize (#5603)
* [branch rebase] rebase main to Feature/resize_embedding (#5554)
* fix
* [release] update version (#5411)
* [hotfix] fix typo s/keywrods/keywords etc. (#5429)
* [devops] fix compatibility (#5444)
* [devops] fix compatibility
* [hotfix] update compatibility test on pr
* [devops] fix compatibility
* [devops] record duration during comp test
* [test] decrease test duration
* fix falcon
* [shardformer] fix gathering output when using tensor parallelism (#5431)
* fix
* padding vocab_size when using pipeline parallellism
padding vocab_size when using pipeline parallellism
fix
fix
* fix
* fix
fix
fix
* fix gather output
* fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* revert
* [doc] release Open-Sora 1.0 with model weights (#5468)
* [doc] release Open-Sora 1.0 with model weights
* [doc] release Open-Sora 1.0 with model weights
* [doc] release Open-Sora 1.0 with model weights
* [doc] update open-sora demo (#5479)
* [doc] update open-sora demo
* [doc] update open-sora demo
* [doc] update open-sora demo
* [example] add grok-1 inference (#5485)
* [misc] add submodule
* remove submodule
* [example] support grok-1 tp inference
* [example] add grok-1 inference script
* [example] refactor code
* [example] add grok-1 readme
* [exmaple] add test ci
* [exmaple] update readme
---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* [CI] run pre-commit (#5577)
* fix
* [release] update version (#5411)
* [hotfix] fix typo s/keywrods/keywords etc. (#5429)
* [devops] fix compatibility (#5444)
* [devops] fix compatibility
* [hotfix] update compatibility test on pr
* [devops] fix compatibility
* [devops] record duration during comp test
* [test] decrease test duration
* fix falcon
* [shardformer] fix gathering output when using tensor parallelism (#5431)
* fix
* padding vocab_size when using pipeline parallellism
padding vocab_size when using pipeline parallellism
fix
fix
* fix
* fix
fix
fix
* fix gather output
* fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* revert
* [doc] release Open-Sora 1.0 with model weights (#5468)
* [doc] release Open-Sora 1.0 with model weights
* [doc] release Open-Sora 1.0 with model weights
* [doc] release Open-Sora 1.0 with model weights
* [doc] update open-sora demo (#5479)
* [doc] update open-sora demo
* [doc] update open-sora demo
* [doc] update open-sora demo
* [example] add grok-1 inference (#5485)
* [misc] add submodule
* remove submodule
* [example] support grok-1 tp inference
* [example] add grok-1 inference script
* [example] refactor code
* [example] add grok-1 readme
* [exmaple] add test ci
* [exmaple] update readme
* run pre-commit
---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* [rebase] rebase main to resize-embedding (#5581)
* [release] grok-1 314b inference (#5490)
* [release] grok-1 inference
* [release] grok-1 inference
* [release] grok-1 inference
* [example] update Grok-1 inference (#5495)
* revise grok-1 example
* remove unused arg in scripts
* prevent re-installing torch
* update readme
* revert modifying colossalai requirements
* add perf
* trivial
* add tokenizer url
* [hotfix] set return_outputs=False in examples and polish code (#5404)
* fix: simplify merge_batch
* fix: use return_outputs=False to eliminate extra memory consumption
* feat: add return_outputs warning
* style: remove `return_outputs=False` as it is the default value
* [release] grok-1 inference benchmark (#5500)
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [shardformer]Fix lm parallel. (#5480)
* fix
* padding vocab_size when using pipeline parallellism
padding vocab_size when using pipeline parallellism
fix
fix
* fix
* fix
fix
fix
* fix gather output
* fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* revert
* fix lm forward distribution
* fix
* test ci
* fix
* [fix] fix grok-1 example typo (#5506)
* [devops] fix example test ci (#5504)
* Fix ColoTensorSpec for py11 (#5440)
* fixed layout converter caching and updated tester
* Empty-Commit
* [shardformer] update colo attention to support custom mask (#5510)
* [feature] refactor colo attention (#5462)
* [extension] update api
* [feature] add colo attention
* [feature] update sdpa
* [feature] update npu attention
* [feature] update flash-attn
* [test] add flash attn test
* [test] update flash attn test
* [shardformer] update modeling to fit colo attention (#5465)
* [misc] refactor folder structure
* [shardformer] update llama flash-attn
* [shardformer] fix llama policy
* [devops] update tensornvme install
* [test] update llama test
* [shardformer] update colo attn kernel dispatch
* [shardformer] update blip2
* [shardformer] update chatglm
* [shardformer] update gpt2
* [shardformer] update gptj
* [shardformer] update opt
* [shardformer] update vit
* [shardformer] update colo attention mask prep
* [shardformer] update whisper
* [test] fix shardformer tests (#5514)
* [test] fix shardformer tests
* [test] fix shardformer tests
* [format] applied code formatting on changed files in pull request 5510 (#5517)
Co-authored-by: github-actions <github-actions@github.com>
* [shardformer] fix pipeline forward error if custom layer distribution is used (#5189)
* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution
* Change static methods for t5 layer distribution to member functions
* Change static methods for whisper layer distribution to member functions
* Replace whisper policy usage with self one
* Fix test case to use non-static layer distribution methods
* fix: fix typo
---------
Co-authored-by: Wenhao Chen <cwher@outlook.com>
* [Fix] Grok-1 use tokenizer from the same pretrained path (#5532)
* [fix] use tokenizer from the same pretrained path
* trust remote code
* [ColossalChat] Update RLHF V2 (#5286)
* Add dpo. Fix sft, ppo, lora. Refactor all
* fix and tested ppo
* 2 nd round refactor
* add ci tests
* fix ci
* fix ci
* fix readme, style
* fix readme style
* fix style, fix benchmark
* reproduce benchmark result, remove useless files
* rename to ColossalChat
* use new image
* fix ci workflow
* fix ci
* use local model/tokenizer for ci tests
* fix ci
* fix ci
* fix ci
* fix ci timeout
* fix rm progress bar. fix ci timeout
* fix ci
* fix ci typo
* remove 3d plugin from ci temporary
* test environment
* cannot save optimizer
* support chat template
* fix readme
* fix path
* test ci locally
* restore build_or_pr
* fix ci data path
* fix benchmark
* fix ci, move ci tests to 3080, disable fast tokenizer
* move ci to 85
* support flash attention 2
* add all-in-one data preparation script. Fix colossal-llama2-chat chat template
* add hardware requirements
* move ci test data
* fix save_model, add unwrap
* fix missing bos
* fix missing bos; support grad accumulation with gemini
* fix ci
* fix ci
* fix ci
* fix llama2 chat template config
* debug sft
* debug sft
* fix colossalai version requirement
* fix ci
* add sanity check to prevent NaN loss
* fix requirements
* add dummy data generation script
* add dummy data generation script
* add dummy data generation script
* add dummy data generation script
* update readme
* update readme
* update readme and ignore
* fix logger bug
* support parallel_output
* modify data preparation logic
* fix tokenization
* update lr
* fix inference
* run pre-commit
---------
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
* [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogenous shard policy for llama (#5508)
* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`
* feat: apply `GradientCheckpointConfig` to policy and llama_forward
* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager
* fix: add optional args for `distribute_layer` and `get_stage_index`
* fix: fix changed API calls
* test: update llama tests
* style: polish `GradientCheckpointConfig`
* fix: fix pipeline utils tests
* fix incorrect sharding without zero (#5545)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [shardformer] Sequence Parallelism Optimization (#5533)
* sequence parallel optimization
* validate sequence parallel in llama (code to be polished)
* shardformer api writing
* integrate sequence parallel in ShardFormer
* fix pp bugs and sp bugs for LlaMa model
* integrating ring-based sequence parallelism into ShardFormer
* [sequence parallelism]: Add fused megatron function
* integrating ring-based sequence parallelism into ShardFormer
---------
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
* fix bugs when useing sp and flashattention together
* fix operation function name
* support flash attention for ulysses-style sp
* clarify sp process group
* fix compatibility bugs in moe plugin
* fix fused linear bugs
* fix linear layer test
* support gpt model all-to-all sp
* modify shard data dimension (meant to be dim=-1)
* support megtron-style sp and distributed attn for llama model
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatability
* finish sp mode 3 support for gpt
* using all_to_all_single when batch size is 1
* support mode 2 sp in gpt2 (#5)
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatability
* refactor ring implementation
* support mode 2 sp in gpt2
* polish code
* enable distributed attn mask when using sp mode 2 and 3 in llama
* automatically enable flash attn when using sp mode 2 and 3 in llama
* inplace attn mask
* add zero2 support for sequence parallel
* polish code
* fix bugs
* fix gemini checkpoint io
* loose tensor checking atol and rtol
* add comment
* fix llama layernorm grad
* fix zero grad
* fix zero grad
* fix conflict
* update split and gather auto grad func
* sequence parallel: inside text split (#6)
* polish code (part 1)
* polish code (part 2)
* polish code (part 2.5)
* polish code (part 3)
* sequence parallel: inside text split
* miscellaneous minor fixes
* polish code
* fix ulysses style ZeRO
* sequence parallel: inside text split
* miscellaneous minor fixes
* disaggregate sp group and dp group for sp
* fix llama and gpt sp
* polish code
* move ulysses grad sync to ddp (#9)
* remove zero_stage and unbind the grad sync for alltoall sp
* add 2d group creation test
* move ulysses grad sync to ddp
* add 2d group creation test
* remove useless code
* change shard config not to enable sp when enable_all_optimizations
* add sp warnings for several model
* remove useless code
---------
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
* [hotfix] quick fixes to make legacy tutorials runnable (#5559)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [fix] fix typo s/muiti-node /multi-node etc. (#5448)
* [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548)
* [devops] remove post commit ci (#5566)
* [devops] remove post commit ci
* [misc] run pre-commit on all files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---------
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [shardformer]enable padding vocabulary size. (#5489)
* padding vocab_size when using pipeline parallellism
padding vocab_size when using pipeline parallellism
fix
fix
* fix
* fix
fix
fix
* fix gather output
* fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* revert
* padding vocab
* padding vocabe
* fix
* fix
* fxi
* test ci
* fix
fix
fix
fix
* fix
fix
* fix
* fix
* Update hybrid_parallel_plugin.py
fix
fix
fix
* fix
fix
* fix
fix
* fix
* resolve super init
resolve super init
resolve super init
resolve super init
* resolve comments
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* vocab checkpointio
* padding vocab_size when using pipeline parallellism
padding vocab_size when using pipeline parallellism
fix
fix
* fix
fix
fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* padding vocab
* fix
* fix
fix
* fix
fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* cherry-pick
* revert moe modify
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
fix
fix
fix
fix
fix
fix
fix
* resolve comments
resolve comments
resolve comments
resolve comments
resolve comments
* ptensor
ptensor
resolve comments
fix
fix
fix
fix
fix
resolve comments
resolve comments
resolve comments
resolve comments
resolve comments
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix rebase
* fix rebase
---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-04-18 08:10:18 +00:00
params_info = get_param_info ( optimizer )
2023-03-31 08:06:13 +00:00
if not isinstance ( model , ModelWrapper ) :
# convert model to sync bn
# FIXME(ver217): gemini does not support sync bn
# In torch/nn/modules/_functions.py, line 22, ``mean, invstd = torch.batch_norm_stats(input, eps)`` will get fp32 mean and invstd even though the input is fp16.
# This inconsistency of dtype will cause the error.
# We have two possible solutions:
# 1. keep batch norm always in fp32. This is hard for gemini, as it use chunks.
# 2. patch sync bn or write a new on. This is relatively easy, but we need to test it.
# model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
# wrap the model with Gemini
2023-11-10 02:15:16 +00:00
if self . enable_tensor_parallelism :
shardformer = ShardFormer ( self . shard_config )
model , _ = shardformer . optimize ( model )
2023-11-20 11:46:43 +00:00
model = GeminiDDP (
model ,
* * self . gemini_config ,
zero_group = self . zero_group ,
extra_dp_group = self . extra_dp_group ,
verbose = self . verbose ,
)
2023-03-31 08:06:13 +00:00
2023-09-19 06:20:26 +00:00
if optimizer is not None and not isinstance ( optimizer , OptimizerWrapper ) :
optimizer = GeminiOptimizer (
2023-11-10 02:15:16 +00:00
optimizer ,
model ,
* * self . zero_optim_config ,
* * self . optim_kwargs ,
tp_group = self . tp_group ,
[shardformer] refactor embedding resize (#5603)
* [branch rebase] rebase main to Feature/resize_embedding (#5554)
* fix
* [release] update version (#5411)
* [hotfix] fix typo s/keywrods/keywords etc. (#5429)
* [devops] fix compatibility (#5444)
* [devops] fix compatibility
* [hotfix] update compatibility test on pr
* [devops] fix compatibility
* [devops] record duration during comp test
* [test] decrease test duration
* fix falcon
* [shardformer] fix gathering output when using tensor parallelism (#5431)
* fix
* padding vocab_size when using pipeline parallellism
padding vocab_size when using pipeline parallellism
fix
fix
* fix
* fix
fix
fix
* fix gather output
* fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* revert
* [doc] release Open-Sora 1.0 with model weights (#5468)
* [doc] release Open-Sora 1.0 with model weights
* [doc] release Open-Sora 1.0 with model weights
* [doc] release Open-Sora 1.0 with model weights
* [doc] update open-sora demo (#5479)
* [doc] update open-sora demo
* [doc] update open-sora demo
* [doc] update open-sora demo
* [example] add grok-1 inference (#5485)
* [misc] add submodule
* remove submodule
* [example] support grok-1 tp inference
* [example] add grok-1 inference script
* [example] refactor code
* [example] add grok-1 readme
* [exmaple] add test ci
* [exmaple] update readme
---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* [CI] run pre-commit (#5577)
* fix
* [release] update version (#5411)
* [hotfix] fix typo s/keywrods/keywords etc. (#5429)
* [devops] fix compatibility (#5444)
* [devops] fix compatibility
* [hotfix] update compatibility test on pr
* [devops] fix compatibility
* [devops] record duration during comp test
* [test] decrease test duration
* fix falcon
* [shardformer] fix gathering output when using tensor parallelism (#5431)
* fix
* padding vocab_size when using pipeline parallellism
padding vocab_size when using pipeline parallellism
fix
fix
* fix
* fix
fix
fix
* fix gather output
* fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* revert
* [doc] release Open-Sora 1.0 with model weights (#5468)
* [doc] release Open-Sora 1.0 with model weights
* [doc] release Open-Sora 1.0 with model weights
* [doc] release Open-Sora 1.0 with model weights
* [doc] update open-sora demo (#5479)
* [doc] update open-sora demo
* [doc] update open-sora demo
* [doc] update open-sora demo
* [example] add grok-1 inference (#5485)
* [misc] add submodule
* remove submodule
* [example] support grok-1 tp inference
* [example] add grok-1 inference script
* [example] refactor code
* [example] add grok-1 readme
* [exmaple] add test ci
* [exmaple] update readme
* run pre-commit
---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* [rebase] rebase main to resize-embedding (#5581)
* [release] grok-1 314b inference (#5490)
* [release] grok-1 inference
* [release] grok-1 inference
* [release] grok-1 inference
* [example] update Grok-1 inference (#5495)
* revise grok-1 example
* remove unused arg in scripts
* prevent re-installing torch
* update readme
* revert modifying colossalai requirements
* add perf
* trivial
* add tokenizer url
* [hotfix] set return_outputs=False in examples and polish code (#5404)
* fix: simplify merge_batch
* fix: use return_outputs=False to eliminate extra memory consumption
* feat: add return_outputs warning
* style: remove `return_outputs=False` as it is the default value
* [release] grok-1 inference benchmark (#5500)
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [release] grok-1 inference benchmark
* [shardformer]Fix lm parallel. (#5480)
* fix
* padding vocab_size when using pipeline parallellism
padding vocab_size when using pipeline parallellism
fix
fix
* fix
* fix
fix
fix
* fix gather output
* fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* revert
* fix lm forward distribution
* fix
* test ci
* fix
* [fix] fix grok-1 example typo (#5506)
* [devops] fix example test ci (#5504)
* Fix ColoTensorSpec for py11 (#5440)
* fixed layout converter caching and updated tester
* Empty-Commit
* [shardformer] update colo attention to support custom mask (#5510)
* [feature] refactor colo attention (#5462)
* [extension] update api
* [feature] add colo attention
* [feature] update sdpa
* [feature] update npu attention
* [feature] update flash-attn
* [test] add flash attn test
* [test] update flash attn test
* [shardformer] update modeling to fit colo attention (#5465)
* [misc] refactor folder structure
* [shardformer] update llama flash-attn
* [shardformer] fix llama policy
* [devops] update tensornvme install
* [test] update llama test
* [shardformer] update colo attn kernel dispatch
* [shardformer] update blip2
* [shardformer] update chatglm
* [shardformer] update gpt2
* [shardformer] update gptj
* [shardformer] update opt
* [shardformer] update vit
* [shardformer] update colo attention mask prep
* [shardformer] update whisper
* [test] fix shardformer tests (#5514)
* [test] fix shardformer tests
* [test] fix shardformer tests
* [format] applied code formatting on changed files in pull request 5510 (#5517)
Co-authored-by: github-actions <github-actions@github.com>
* [shardformer] fix pipeline forward error if custom layer distribution is used (#5189)
* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution
* Change static methods for t5 layer distribution to member functions
* Change static methods for whisper layer distribution to member functions
* Replace whisper policy usage with self one
* Fix test case to use non-static layer distribution methods
* fix: fix typo
---------
Co-authored-by: Wenhao Chen <cwher@outlook.com>
* [Fix] Grok-1 use tokenizer from the same pretrained path (#5532)
* [fix] use tokenizer from the same pretrained path
* trust remote code
* [ColossalChat] Update RLHF V2 (#5286)
* Add dpo. Fix sft, ppo, lora. Refactor all
* fix and tested ppo
* 2 nd round refactor
* add ci tests
* fix ci
* fix ci
* fix readme, style
* fix readme style
* fix style, fix benchmark
* reproduce benchmark result, remove useless files
* rename to ColossalChat
* use new image
* fix ci workflow
* fix ci
* use local model/tokenizer for ci tests
* fix ci
* fix ci
* fix ci
* fix ci timeout
* fix rm progress bar. fix ci timeout
* fix ci
* fix ci typo
* remove 3d plugin from ci temporary
* test environment
* cannot save optimizer
* support chat template
* fix readme
* fix path
* test ci locally
* restore build_or_pr
* fix ci data path
* fix benchmark
* fix ci, move ci tests to 3080, disable fast tokenizer
* move ci to 85
* support flash attention 2
* add all-in-one data preparation script. Fix colossal-llama2-chat chat template
* add hardware requirements
* move ci test data
* fix save_model, add unwrap
* fix missing bos
* fix missing bos; support grad accumulation with gemini
* fix ci
* fix ci
* fix ci
* fix llama2 chat template config
* debug sft
* debug sft
* fix colossalai version requirement
* fix ci
* add sanity check to prevent NaN loss
* fix requirements
* add dummy data generation script
* add dummy data generation script
* add dummy data generation script
* add dummy data generation script
* update readme
* update readme
* update readme and ignore
* fix logger bug
* support parallel_output
* modify data preparation logic
* fix tokenization
* update lr
* fix inference
* run pre-commit
---------
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
* [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogenous shard policy for llama (#5508)
* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`
* feat: apply `GradientCheckpointConfig` to policy and llama_forward
* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager
* fix: add optional args for `distribute_layer` and `get_stage_index`
* fix: fix changed API calls
* test: update llama tests
* style: polish `GradientCheckpointConfig`
* fix: fix pipeline utils tests
* fix incorrect sharding without zero (#5545)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [shardformer] Sequence Parallelism Optimization (#5533)
* sequence parallel optimization
* validate sequence parallel in llama (code to be polished)
* shardformer api writing
* integrate sequence parallel in ShardFormer
* fix pp bugs and sp bugs for LlaMa model
* integrating ring-based sequence parallelism into ShardFormer
* [sequence parallelism]: Add fused megatron function
* integrating ring-based sequence parallelism into ShardFormer
---------
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
* fix bugs when useing sp and flashattention together
* fix operation function name
* support flash attention for ulysses-style sp
* clarify sp process group
* fix compatibility bugs in moe plugin
* fix fused linear bugs
* fix linear layer test
* support gpt model all-to-all sp
* modify shard data dimension (meant to be dim=-1)
* support megtron-style sp and distributed attn for llama model
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatability
* finish sp mode 3 support for gpt
* using all_to_all_single when batch size is 1
* support mode 2 sp in gpt2 (#5)
* [shardformer] add megatron sp to llama
* support llama7B 128k with distributed attention
* [shardformer] robustness enhancement
* add block attn
* sp mode 1: keep input as a complete sequence
* fix sp compatability
* refactor ring implementation
* support mode 2 sp in gpt2
* polish code
* enable distributed attn mask when using sp mode 2 and 3 in llama
* automatically enable flash attn when using sp mode 2 and 3 in llama
* inplace attn mask
* add zero2 support for sequence parallel
* polish code
* fix bugs
* fix gemini checkpoint io
* loose tensor checking atol and rtol
* add comment
* fix llama layernorm grad
* fix zero grad
* fix zero grad
* fix conflict
* update split and gather auto grad func
* sequence parallel: inside text split (#6)
* polish code (part 1)
* polish code (part 2)
* polish code (part 2.5)
* polish code (part 3)
* sequence parallel: inside text split
* miscellaneous minor fixes
* polish code
* fix ulysses style ZeRO
* sequence parallel: inside text split
* miscellaneous minor fixes
* disaggregate sp group and dp group for sp
* fix llama and gpt sp
* polish code
* move ulysses grad sync to ddp (#9)
* remove zero_stage and unbind the grad sync for alltoall sp
* add 2d group creation test
* move ulysses grad sync to ddp
* add 2d group creation test
* remove useless code
* change shard config not to enable sp when enable_all_optimizations
* add sp warnings for several model
* remove useless code
---------
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
* [hotfix] quick fixes to make legacy tutorials runnable (#5559)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [fix] fix typo s/muiti-node /multi-node etc. (#5448)
* [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548)
* [devops] remove post commit ci (#5566)
* [devops] remove post commit ci
* [misc] run pre-commit on all files
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---------
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [shardformer]enable padding vocabulary size. (#5489)
* padding vocab_size when using pipeline parallellism
padding vocab_size when using pipeline parallellism
fix
fix
* fix
* fix
fix
fix
* fix gather output
* fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* revert
* padding vocab
* padding vocabe
* fix
* fix
* fxi
* test ci
* fix
fix
fix
fix
* fix
fix
* fix
* fix
* Update hybrid_parallel_plugin.py
fix
fix
fix
* fix
fix
* fix
fix
* fix
* resolve super init
resolve super init
resolve super init
resolve super init
* resolve comments
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* vocab checkpointio
* padding vocab_size when using pipeline parallellism
padding vocab_size when using pipeline parallellism
fix
fix
* fix
fix
fix
* fix
* fix
fix resize embedding
fix resize embedding
* fix resize embedding
fix
* revert
* revert
* padding vocab
* fix
* fix
fix
* fix
fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix ci
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* cherry-pick
* revert moe modify
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
fix
fix
fix
fix
fix
fix
fix
* resolve comments
resolve comments
resolve comments
resolve comments
resolve comments
* ptensor
ptensor
resolve comments
fix
fix
fix
fix
fix
resolve comments
resolve comments
resolve comments
resolve comments
resolve comments
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix rebase
* fix rebase
---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2024-04-18 08:10:18 +00:00
params_info = params_info ,
2023-11-10 02:15:16 +00:00
verbose = self . verbose ,
2023-09-19 06:20:26 +00:00
)
2023-03-31 08:06:13 +00:00
return model , optimizer , criterion , dataloader , lr_scheduler
def control_checkpoint_io ( self ) - > bool :
return True
def get_checkpoint_io ( self ) - > CheckpointIO :
return GeminiCheckpointIO ( )
2023-05-09 03:10:02 +00:00
2023-07-04 04:00:33 +00:00
def no_sync ( self , model : nn . Module , optimizer : OptimizerWrapper ) - > Iterator [ None ] :
2023-11-20 08:12:41 +00:00
raise NotImplementedError
2023-10-31 07:19:37 +00:00
def enable_lora (
self , model : nn . Module , pretrained_dir : Optional [ str ] = None , lora_config : Optional [ Dict ] = None
) - > nn . Module :
raise NotImplementedError