mirror of https://github.com/hpcaitech/ColossalAI
[chat] fix gemini strategy (#4698)
* [chat] fix gemini strategy * [chat] fix gemini strategy * [chat] fix gemini strategy * [chat] fix gemini strategy * g# This is a combination of 2 commits. [chat] fix gemini strategy fox * [chat] fix gemini strategy update llama2 example [chat] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * fix * fix * fix * fix * fix * Update train_prompts.pypull/4821/head
parent
bbbcac26e8
commit
be400a0936
|
@ -76,9 +76,9 @@ def main(args):
|
|||
if args.strategy == "ddp":
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == "colossalai_gemini":
|
||||
strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
|
||||
strategy = GeminiStrategy(placement_policy="static",initial_scale=2**5)
|
||||
elif args.strategy == "colossalai_gemini_cpu":
|
||||
strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
|
||||
strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
|
||||
elif args.strategy == "colossalai_zero2":
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
||||
elif args.strategy == "colossalai_zero2_cpu":
|
||||
|
|
|
@ -30,3 +30,4 @@ class Actor(LoRAModule):
|
|||
"""Returns model output."""
|
||||
output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
|
||||
return output
|
||||
|
||||
|
|
|
@ -71,11 +71,11 @@ def get_strategy_from_args(strategy: str):
|
|||
if strategy == "ddp":
|
||||
strategy_ = DDPStrategy()
|
||||
elif strategy == "colossalai_gemini":
|
||||
strategy_ = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
|
||||
strategy_ = GeminiStrategy(placement_policy="static", initial_scale=2**5)
|
||||
elif strategy == "colossalai_zero2":
|
||||
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
||||
elif strategy == "colossalai_gemini_cpu":
|
||||
strategy_ = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
|
||||
strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
|
||||
elif strategy == "colossalai_zero2_cpu":
|
||||
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
|
||||
else:
|
||||
|
|
|
@ -110,8 +110,8 @@ class Strategy(ABC):
|
|||
"""
|
||||
return model
|
||||
|
||||
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True, **kwargs) -> None:
|
||||
self.booster.save_model(model, path, shard=not only_rank0, **kwargs)
|
||||
def save_model(self, model: nn.Module, path: str, shard: bool = False, **kwargs) -> None:
|
||||
self.booster.save_model(model, path, shard=shard, **kwargs)
|
||||
|
||||
def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
|
||||
self.booster.load_model(model, path, strict)
|
||||
|
|
|
@ -6,7 +6,6 @@ import torch.nn as nn
|
|||
import colossalai
|
||||
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
|
||||
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
|
||||
from colossalai.lazy.lazy_init import LazyInitContext
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
|
||||
|
||||
|
@ -130,6 +129,9 @@ class GeminiStrategy(DDPStrategy):
|
|||
seed: int = 42,
|
||||
shard_init: bool = False, # only for stage 3
|
||||
placement_policy: str = "auto",
|
||||
shard_param_frac: float = 1.0, # only for static placement
|
||||
offload_optim_frac: float = 0.0, # only for static placement
|
||||
offload_param_frac: float = 0.0, # only for static placement
|
||||
pin_memory: bool = True, # only for stage 3
|
||||
force_outputs_fp32: bool = False, # only for stage 3
|
||||
search_range_m: int = 32, # only for stage 3
|
||||
|
@ -160,6 +162,9 @@ class GeminiStrategy(DDPStrategy):
|
|||
plugin_initializer = lambda: GeminiPlugin(
|
||||
chunk_init_device=get_current_device(),
|
||||
placement_policy=placement_policy,
|
||||
shard_param_frac=shard_param_frac,
|
||||
offload_optim_frac=offload_optim_frac,
|
||||
offload_param_frac=offload_param_frac,
|
||||
precision="fp16",
|
||||
pin_memory=pin_memory,
|
||||
force_outputs_fp32=force_outputs_fp32,
|
||||
|
@ -188,7 +193,7 @@ class GeminiStrategy(DDPStrategy):
|
|||
colossalai.launch_from_torch({}, seed=self.seed)
|
||||
|
||||
def model_init_context(self):
|
||||
return LazyInitContext(default_device=get_current_device())
|
||||
return super().model_init_context()
|
||||
|
||||
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
||||
assert isinstance(model, GeminiDDP)
|
||||
|
|
|
@ -87,9 +87,9 @@ class DDPStrategy(Strategy):
|
|||
return model.unwrap()
|
||||
|
||||
def save_pretrained(
|
||||
self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
|
||||
self, model: nn.Module, path: str, shard: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None
|
||||
) -> None:
|
||||
if not only_rank0 or dist.get_rank() == 0:
|
||||
if dist.get_rank() == 0:
|
||||
unwrapped_model = self.unwrap_model(model)
|
||||
assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
|
||||
pretrained_model = unwrapped_model.model
|
||||
|
@ -98,19 +98,19 @@ class DDPStrategy(Strategy):
|
|||
pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None)
|
||||
if tokenizer is not None:
|
||||
tokenizer.save_pretrained(path)
|
||||
model_path = os.path.join(path, "pytorch_model.bin")
|
||||
self.save_model(model, model_path, only_rank0=only_rank0)
|
||||
|
||||
model_path = os.path.join(path, "pytorch_model.bin")
|
||||
self.save_model(model, model_path, shard=shard)
|
||||
def _replace_keys(model_path: str, replace_fn: Callable):
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
state_dict = {replace_fn(k): v for k, v in state_dict.items()}
|
||||
torch.save(state_dict, model_path)
|
||||
|
||||
# FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
|
||||
# HACK: rename keys of pytorch_model.bin
|
||||
if dist.get_rank() == 0:
|
||||
_replace_keys(model_path, lambda k: k.replace("model.", "", 1))
|
||||
|
||||
|
||||
def get_model_state_dict_shard(self, model: nn.Module, **config):
|
||||
# TODO: implement sharding on naive strategy
|
||||
model = self.unwrap_model(model)
|
||||
|
|
|
@ -24,7 +24,7 @@ def main(args):
|
|||
if args.strategy == "ddp":
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == "colossalai_gemini":
|
||||
strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
|
||||
strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
|
||||
elif args.strategy == "colossalai_zero2":
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
|
||||
else:
|
||||
|
|
|
@ -24,7 +24,7 @@ def train(args):
|
|||
if args.strategy == "ddp":
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == "colossalai_gemini":
|
||||
strategy = GeminiStrategy(placement_policy="cuda")
|
||||
strategy = GeminiStrategy(placement_policy="static")
|
||||
elif args.strategy == "colossalai_zero2":
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
||||
else:
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
pandas>=1.4.1
|
||||
sentencepiece
|
||||
colossalai>=0.3.1
|
||||
colossalai==0.3.3
|
||||
|
|
|
@ -23,7 +23,7 @@ def main(args):
|
|||
if args.strategy == "ddp":
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == "colossalai_gemini":
|
||||
strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5)
|
||||
strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5)
|
||||
elif args.strategy == "colossalai_zero2":
|
||||
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
|
||||
else:
|
||||
|
@ -33,6 +33,10 @@ def main(args):
|
|||
warnings.warn("LoRA weights should be merged with the model weights")
|
||||
state_dict = torch.load(args.rm_path, map_location="cpu")
|
||||
|
||||
if args.lora_rank > 0:
|
||||
warnings.warn("Lora is not supported yet.")
|
||||
args.lora_rank = 0
|
||||
|
||||
with strategy.model_init_context():
|
||||
# configure model
|
||||
if args.model == "gpt2":
|
||||
|
@ -199,7 +203,7 @@ def main(args):
|
|||
LORA_MANAGER.merge_weights = True
|
||||
actor.eval()
|
||||
# save model checkpoint after fitting
|
||||
strategy.save_model(actor, args.save_path, only_rank0=True)
|
||||
strategy.save_pretrained(actor, path=args.save_path)
|
||||
# save optimizer checkpoint on all ranks
|
||||
if args.need_optim_ckpt:
|
||||
strategy.save_optimizer(
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import argparse
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -33,6 +34,10 @@ def train(args):
|
|||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
||||
|
||||
# configure model
|
||||
if args.lora_rank > 0:
|
||||
warnings.warn("Lora is not supported yet.")
|
||||
args.lora_rank = 0
|
||||
|
||||
with strategy.model_init_context():
|
||||
if args.model == "bloom":
|
||||
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
|
||||
|
@ -165,7 +170,8 @@ def train(args):
|
|||
LORA_MANAGER.merge_weights = True
|
||||
model.eval()
|
||||
# save model checkpoint after fitting on only rank0
|
||||
strategy.save_model(model, args.save_path, only_rank0=True)
|
||||
state_dict = model.state_dict()
|
||||
torch.save(state_dict, args.save_path)
|
||||
# save optimizer checkpoint on all ranks
|
||||
if args.need_optim_ckpt:
|
||||
strategy.save_optimizer(
|
||||
|
|
|
@ -40,8 +40,9 @@ def train(args):
|
|||
|
||||
# configure model
|
||||
if args.lora_rank > 0:
|
||||
warnings.warn("Gradient checkpoint is disabled when using LoRA")
|
||||
args.grad_checkpoint = False
|
||||
warnings.warn("Lora is not supported yet.")
|
||||
args.lora_rank = 0
|
||||
|
||||
with strategy.model_init_context():
|
||||
if args.model == "bloom":
|
||||
model = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
|
||||
|
@ -184,7 +185,7 @@ def train(args):
|
|||
LORA_MANAGER.merge_weights = True
|
||||
model.eval()
|
||||
# save model checkpoint after fitting on only rank0
|
||||
strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
|
||||
strategy.save_pretrained(model, path=args.save_path, tokenizer=tokenizer)
|
||||
# save optimizer checkpoint on all ranks
|
||||
if args.need_optim_ckpt:
|
||||
strategy.save_optimizer(
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
pytest
|
||||
colossalai>=0.3.1
|
||||
colossalai==0.3.3
|
||||
|
|
|
@ -2,7 +2,7 @@ transformers>=4.20.1
|
|||
tqdm
|
||||
datasets
|
||||
loralib
|
||||
colossalai>=0.3.1
|
||||
colossalai==0.3.3
|
||||
torch<2.0.0, >=1.12.1
|
||||
langchain
|
||||
tokenizers
|
||||
|
|
|
@ -57,9 +57,9 @@ def run_test_checkpoint(strategy_name: str, shard: bool):
|
|||
rank0_dirname = rank0_dirname[0]
|
||||
|
||||
model_path = os.path.join(rank0_dirname, "model" if shard else f"model.pt")
|
||||
strategy.save_model(actor, model_path, only_rank0=not shard)
|
||||
strategy.save_model(actor, model_path)
|
||||
optim_path = os.path.join(rank0_dirname, "optim" if shard else "optim.pt")
|
||||
strategy.save_optimizer(actor_optim, optim_path, only_rank0=not shard)
|
||||
strategy.save_optimizer(actor_optim, optim_path)
|
||||
dist.barrier()
|
||||
|
||||
strategy.load_model(actor, model_path, strict=False)
|
||||
|
|
|
@ -41,6 +41,7 @@ MODELS_DIR=$BASE_DIR/examples/models_config
|
|||
MODELS=('gpt2' 'bloom' 'opt' 'llama')
|
||||
STRATEGIES=('ddp' 'colossalai_gemini' 'colossalai_zero2')
|
||||
|
||||
|
||||
export OMP_NUM_THREADS=8
|
||||
|
||||
# install requirements
|
||||
|
@ -80,13 +81,10 @@ SKIPPED_TESTS=(
|
|||
"llama-ddp"
|
||||
"llama-colossalai_gemini"
|
||||
"llama-colossalai_zero2"
|
||||
"gpt2-colossalai_gemini"
|
||||
"opt-colossalai_gemini"
|
||||
"bloom-colossalai_gemini"
|
||||
)
|
||||
|
||||
GRAD_CKPTS=('' '--grad_checkpoint')
|
||||
for lora_rank in '0' '4'; do
|
||||
for lora_rank in '0'; do
|
||||
for model in ${MODELS[@]}; do
|
||||
strategies=($(shuf -e "${STRATEGIES[@]}"))
|
||||
for strategy in ${strategies[@]}; do
|
||||
|
@ -135,14 +133,11 @@ SKIPPED_TESTS=(
|
|||
"llama-ddp"
|
||||
"llama-colossalai_gemini"
|
||||
"llama-colossalai_zero2"
|
||||
"gpt2-colossalai_gemini"
|
||||
"opt-colossalai_gemini"
|
||||
"bloom-colossalai_gemini"
|
||||
)
|
||||
|
||||
LOSS_FNS=('log_sig' 'log_exp')
|
||||
DATASETS=('Anthropic/hh-rlhf' 'Dahoas/rm-static')
|
||||
for lora_rank in '0' '4'; do
|
||||
for lora_rank in '0'; do
|
||||
for model in ${MODELS[@]}; do
|
||||
strategies=($(shuf -e "${STRATEGIES[@]}"))
|
||||
for strategy in ${strategies[@]}; do
|
||||
|
@ -193,13 +188,10 @@ SKIPPED_TESTS=(
|
|||
"llama-ddp"
|
||||
"llama-colossalai_gemini"
|
||||
"llama-colossalai_zero2"
|
||||
"gpt2-colossalai_gemini"
|
||||
"opt-colossalai_gemini"
|
||||
"bloom-colossalai_gemini"
|
||||
)
|
||||
|
||||
for model in ${MODELS[@]}; do
|
||||
for lora_rank in '0' '4'; do
|
||||
for lora_rank in '0'; do
|
||||
strategies=($(shuf -e "${STRATEGIES[@]}"))
|
||||
for strategy in ${strategies[@]}; do
|
||||
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
|
||||
|
@ -223,7 +215,7 @@ for model in ${MODELS[@]}; do
|
|||
--experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \
|
||||
--pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \
|
||||
$rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \
|
||||
--save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts.pt
|
||||
--save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
break
|
||||
|
@ -238,4 +230,4 @@ for model in ${MODELS[@]}; do
|
|||
rm $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
|
||||
done
|
||||
done
|
||||
rm $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts.pt
|
||||
rm -rf $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts
|
||||
|
|
Loading…
Reference in New Issue