[chat] fix gemini strategy (#4698)

* [chat] fix gemini strategy

* [chat] fix gemini strategy

* [chat] fix gemini strategy

* [chat] fix gemini strategy

* g# This is a combination of 2 commits.

[chat] fix gemini strategy

fox

* [chat] fix gemini strategy

update llama2 example

[chat] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* [fix] fix gemini strategy

* fix

* fix

* fix

* fix

* fix

* Update train_prompts.py
pull/4821/head
flybird11111 2023-09-27 13:15:32 +08:00 committed by GitHub
parent bbbcac26e8
commit be400a0936
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 49 additions and 40 deletions

View File

@ -76,9 +76,9 @@ def main(args):
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
strategy = GeminiStrategy(placement_policy="static",initial_scale=2**5)
elif args.strategy == "colossalai_gemini_cpu":
strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
elif args.strategy == "colossalai_zero2_cpu":

View File

@ -30,3 +30,4 @@ class Actor(LoRAModule):
"""Returns model output."""
output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
return output

View File

@ -71,11 +71,11 @@ def get_strategy_from_args(strategy: str):
if strategy == "ddp":
strategy_ = DDPStrategy()
elif strategy == "colossalai_gemini":
strategy_ = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
strategy_ = GeminiStrategy(placement_policy="static", initial_scale=2**5)
elif strategy == "colossalai_zero2":
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
elif strategy == "colossalai_gemini_cpu":
strategy_ = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
strategy_ = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
elif strategy == "colossalai_zero2_cpu":
strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:

View File

@ -110,8 +110,8 @@ class Strategy(ABC):
"""
return model
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True, **kwargs) -> None:
self.booster.save_model(model, path, shard=not only_rank0, **kwargs)
def save_model(self, model: nn.Module, path: str, shard: bool = False, **kwargs) -> None:
self.booster.save_model(model, path, shard=shard, **kwargs)
def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
self.booster.load_model(model, path, strict)

View File

@ -6,7 +6,6 @@ import torch.nn as nn
import colossalai
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.utils import get_current_device
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
@ -130,6 +129,9 @@ class GeminiStrategy(DDPStrategy):
seed: int = 42,
shard_init: bool = False, # only for stage 3
placement_policy: str = "auto",
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
search_range_m: int = 32, # only for stage 3
@ -160,6 +162,9 @@ class GeminiStrategy(DDPStrategy):
plugin_initializer = lambda: GeminiPlugin(
chunk_init_device=get_current_device(),
placement_policy=placement_policy,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,
offload_param_frac=offload_param_frac,
precision="fp16",
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
@ -188,7 +193,7 @@ class GeminiStrategy(DDPStrategy):
colossalai.launch_from_torch({}, seed=self.seed)
def model_init_context(self):
return LazyInitContext(default_device=get_current_device())
return super().model_init_context()
def unwrap_model(self, model: nn.Module) -> nn.Module:
assert isinstance(model, GeminiDDP)

View File

@ -87,9 +87,9 @@ class DDPStrategy(Strategy):
return model.unwrap()
def save_pretrained(
self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
self, model: nn.Module, path: str, shard: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None
) -> None:
if not only_rank0 or dist.get_rank() == 0:
if dist.get_rank() == 0:
unwrapped_model = self.unwrap_model(model)
assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
pretrained_model = unwrapped_model.model
@ -98,19 +98,19 @@ class DDPStrategy(Strategy):
pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None)
if tokenizer is not None:
tokenizer.save_pretrained(path)
model_path = os.path.join(path, "pytorch_model.bin")
self.save_model(model, model_path, only_rank0=only_rank0)
model_path = os.path.join(path, "pytorch_model.bin")
self.save_model(model, model_path, shard=shard)
def _replace_keys(model_path: str, replace_fn: Callable):
state_dict = torch.load(model_path, map_location="cpu")
state_dict = {replace_fn(k): v for k, v in state_dict.items()}
torch.save(state_dict, model_path)
# FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
# HACK: rename keys of pytorch_model.bin
if dist.get_rank() == 0:
_replace_keys(model_path, lambda k: k.replace("model.", "", 1))
def get_model_state_dict_shard(self, model: nn.Module, **config):
# TODO: implement sharding on naive strategy
model = self.unwrap_model(model)

View File

@ -24,7 +24,7 @@ def main(args):
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
strategy = GeminiStrategy(placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5)
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
else:

View File

@ -24,7 +24,7 @@ def train(args):
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="cuda")
strategy = GeminiStrategy(placement_policy="static")
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:

View File

@ -1,3 +1,3 @@
pandas>=1.4.1
sentencepiece
colossalai>=0.3.1
colossalai==0.3.3

View File

@ -23,7 +23,7 @@ def main(args):
if args.strategy == "ddp":
strategy = DDPStrategy()
elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy="auto", initial_scale=2**5)
strategy = GeminiStrategy(placement_policy="static", initial_scale=2**5)
elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
else:
@ -33,6 +33,10 @@ def main(args):
warnings.warn("LoRA weights should be merged with the model weights")
state_dict = torch.load(args.rm_path, map_location="cpu")
if args.lora_rank > 0:
warnings.warn("Lora is not supported yet.")
args.lora_rank = 0
with strategy.model_init_context():
# configure model
if args.model == "gpt2":
@ -199,7 +203,7 @@ def main(args):
LORA_MANAGER.merge_weights = True
actor.eval()
# save model checkpoint after fitting
strategy.save_model(actor, args.save_path, only_rank0=True)
strategy.save_pretrained(actor, path=args.save_path)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(

View File

@ -1,4 +1,5 @@
import argparse
import warnings
import torch
import torch.distributed as dist
@ -33,6 +34,10 @@ def train(args):
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
if args.lora_rank > 0:
warnings.warn("Lora is not supported yet.")
args.lora_rank = 0
with strategy.model_init_context():
if args.model == "bloom":
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank)
@ -165,7 +170,8 @@ def train(args):
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
strategy.save_model(model, args.save_path, only_rank0=True)
state_dict = model.state_dict()
torch.save(state_dict, args.save_path)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(

View File

@ -40,8 +40,9 @@ def train(args):
# configure model
if args.lora_rank > 0:
warnings.warn("Gradient checkpoint is disabled when using LoRA")
args.grad_checkpoint = False
warnings.warn("Lora is not supported yet.")
args.lora_rank = 0
with strategy.model_init_context():
if args.model == "bloom":
model = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank, checkpoint=args.grad_checkpoint)
@ -184,7 +185,7 @@ def train(args):
LORA_MANAGER.merge_weights = True
model.eval()
# save model checkpoint after fitting on only rank0
strategy.save_pretrained(model, path=args.save_path, only_rank0=True, tokenizer=tokenizer)
strategy.save_pretrained(model, path=args.save_path, tokenizer=tokenizer)
# save optimizer checkpoint on all ranks
if args.need_optim_ckpt:
strategy.save_optimizer(

View File

@ -1,2 +1,2 @@
pytest
colossalai>=0.3.1
colossalai==0.3.3

View File

@ -2,7 +2,7 @@ transformers>=4.20.1
tqdm
datasets
loralib
colossalai>=0.3.1
colossalai==0.3.3
torch<2.0.0, >=1.12.1
langchain
tokenizers

View File

@ -57,9 +57,9 @@ def run_test_checkpoint(strategy_name: str, shard: bool):
rank0_dirname = rank0_dirname[0]
model_path = os.path.join(rank0_dirname, "model" if shard else f"model.pt")
strategy.save_model(actor, model_path, only_rank0=not shard)
strategy.save_model(actor, model_path)
optim_path = os.path.join(rank0_dirname, "optim" if shard else "optim.pt")
strategy.save_optimizer(actor_optim, optim_path, only_rank0=not shard)
strategy.save_optimizer(actor_optim, optim_path)
dist.barrier()
strategy.load_model(actor, model_path, strict=False)

View File

@ -41,6 +41,7 @@ MODELS_DIR=$BASE_DIR/examples/models_config
MODELS=('gpt2' 'bloom' 'opt' 'llama')
STRATEGIES=('ddp' 'colossalai_gemini' 'colossalai_zero2')
export OMP_NUM_THREADS=8
# install requirements
@ -80,13 +81,10 @@ SKIPPED_TESTS=(
"llama-ddp"
"llama-colossalai_gemini"
"llama-colossalai_zero2"
"gpt2-colossalai_gemini"
"opt-colossalai_gemini"
"bloom-colossalai_gemini"
)
GRAD_CKPTS=('' '--grad_checkpoint')
for lora_rank in '0' '4'; do
for lora_rank in '0'; do
for model in ${MODELS[@]}; do
strategies=($(shuf -e "${STRATEGIES[@]}"))
for strategy in ${strategies[@]}; do
@ -135,14 +133,11 @@ SKIPPED_TESTS=(
"llama-ddp"
"llama-colossalai_gemini"
"llama-colossalai_zero2"
"gpt2-colossalai_gemini"
"opt-colossalai_gemini"
"bloom-colossalai_gemini"
)
LOSS_FNS=('log_sig' 'log_exp')
DATASETS=('Anthropic/hh-rlhf' 'Dahoas/rm-static')
for lora_rank in '0' '4'; do
for lora_rank in '0'; do
for model in ${MODELS[@]}; do
strategies=($(shuf -e "${STRATEGIES[@]}"))
for strategy in ${strategies[@]}; do
@ -193,13 +188,10 @@ SKIPPED_TESTS=(
"llama-ddp"
"llama-colossalai_gemini"
"llama-colossalai_zero2"
"gpt2-colossalai_gemini"
"opt-colossalai_gemini"
"bloom-colossalai_gemini"
)
for model in ${MODELS[@]}; do
for lora_rank in '0' '4'; do
for lora_rank in '0'; do
strategies=($(shuf -e "${STRATEGIES[@]}"))
for strategy in ${strategies[@]}; do
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$strategy-$lora_rank " ]]; then
@ -223,7 +215,7 @@ for model in ${MODELS[@]}; do
--experience_batch_size 2 --train_batch_size 1 --lora_rank $lora_rank \
--pretrain $EXAMPLES_DIR/rlhf_models/sft_ckpt_${model}_${lora_rank} \
$rm_pretrain_model --rm_path $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt \
--save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts.pt
--save_path $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts
passed=$?
if [ $passed -eq 0 ]; then
break
@ -238,4 +230,4 @@ for model in ${MODELS[@]}; do
rm $EXAMPLES_DIR/rlhf_models/rm_ckpt_${model}_${lora_rank}.pt
done
done
rm $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts.pt
rm -rf $EXAMPLES_DIR/rlhf_models/actor_checkpoint_prompts