[chat] ChatGPT train prompts on ray example (#3309)

* [feat][chatgpt]train prompts on ray example

* [fix]simplify code

* [fix]remove depreciated parameter

* [fix]add dependencies

* [fix]method calling

* [fix]experience maker

* [fix]missing loss function

* [fix]init optimizer

* [feat]add usage comment

* [fix]rename files

* [fix]add readme

* [fix]file path

* [fix]move directory

---------

Co-authored-by: jiangwen <zxl265370@antgroup.com>
MisterLin1995 2023-04-13 18:18:36 +08:00 committed by GitHub
parent 535b896435
commit 1a809eddaa
4 changed files with 597 additions and 2 deletions


@@ -1,6 +1,6 @@
# Community Examples
---
We are thrilled to announce the latest updates to ColossalChat, an open-source solution for cloning ChatGPT with a complete RLHF (Reinforcement Learning with Human Feedback) pipeline.
As Colossal-AI undergoes major updates, we are actively maintaining ColossalChat to stay aligned with the project's progress. With the introduction of community-driven examples, we aim to create a collaborative platform for developers to contribute exotic features built on top of ColossalChat.
@@ -16,7 +16,8 @@ Community examples consist of both inference and training examples that have bee
| Example | Description | Code Example | Colab | Author |
|:---------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------:|
| Peft | Adding Peft support for SFT and Prompts model training | [Huggingface Peft](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/community/peft) | - | [YY Lin](https://github.com/yynil) |
| Train prompts on Ray | A Ray-based implementation of the train prompts example | [Train prompts on Ray](https://github.com/hpcaitech/ColossalAI/tree/main/applications/Chat/examples/community/ray) | - | [MisterLin1995](https://github.com/MisterLin1995) |
|...|...|...|...|...|
### How to get involved


@@ -0,0 +1,17 @@
# ColossalAI on Ray
## Abstract
This is an experimental effort to run ColossalAI Chat training on Ray.
## How to use?
### 1. Setup Ray clusters
Please follow the official [Ray cluster setup instructions](https://docs.ray.io/en/latest/cluster/getting-started.html) to set up a cluster with GPU support. Record the cluster's API server endpoint; it should look something like http://your.head.node.address:8265.
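If you only need a small test cluster, one rough way to bring it up by hand is sketched below (flag names assume a recent Ray release; adjust ports and addresses to your environment):
```shell
# On the head node: start Ray and expose the dashboard / job API server on port 8265.
ray start --head --dashboard-host=0.0.0.0 --dashboard-port=8265
# On each GPU worker node: join the cluster (6379 is the default head node port).
ray start --address=your.head.node.address:6379
```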
### 2. Clone repo
Clone this project:
```shell
git clone https://github.com/hpcaitech/ColossalAI.git
```
### 3. Submit the Ray job
```shell
python applications/Chat/examples/community/ray/ray_job_script.py http://your.head.node.address:8265
```
### 4. View your job on the Ray Dashboard
Open the Ray cluster dashboard at http://your.head.node.address:8265 to view your submitted training job.
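Alternatively, the job can be inspected from the command line. The sketch below assumes the Ray job CLI is available (it ships with `ray[default]`) and that `JOB_ID` is replaced with the submission ID shown on the dashboard:
```shell
# List submitted jobs, then check the status and stream the logs of one of them.
ray job list --address http://your.head.node.address:8265
ray job status JOB_ID --address http://your.head.node.address:8265
ray job logs JOB_ID --address http://your.head.node.address:8265 --follow
```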


@@ -0,0 +1,22 @@
import sys

from ray.job_submission import JobSubmissionClient


def main(api_server_endpoint="http://127.0.0.1:8265"):
    # Connect to the job API server exposed by the Ray head node.
    client = JobSubmissionClient(api_server_endpoint)
    # Submit the training job; the runtime environment uploads the working
    # directory and installs the pip dependencies on the cluster.
    client.submit_job(
        entrypoint="python experimental/ray/train_prompts_on_ray.py --strategy colossalai_zero2 --prompt_csv_url https://huggingface.co/datasets/fka/awesome-chatgpt-prompts/resolve/main/prompts.csv",
        runtime_env={
            "working_dir": "applications/Chat",
            "pip": [
                "torch==1.13.1", "transformers>=4.20.1", "datasets", "loralib", "colossalai>=0.2.4", "langchain",
                "tokenizers", "fastapi", "sse_starlette", "wandb", "sentencepiece", "gpustat"
            ]
        })


if __name__ == "__main__":
    # The first command line argument is the cluster's API server endpoint,
    # e.g. http://your.head.node.address:8265.
    main(sys.argv[1])


@@ -0,0 +1,555 @@
import argparse
import logging
import os
import socket
from copy import deepcopy
from typing import Type
import ray
import torch
from coati.experience_maker.base import Experience
from coati.models.base import RewardModel
from coati.models.bloom import BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTActor, GPTCritic
from coati.models.lora import LoRAModule
from coati.models.loss import PolicyLoss, ValueLoss
from coati.models.opt import OPTActor, OPTCritic
from coati.models.utils import compute_reward
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.nn.optimizer import HybridAdam
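
# Bundles the Ray ObjectRefs of the intermediate tensors produced for one batch
# (generated sequences/masks, actor and initial-model log probs, critic value and
# reward score) so they can be resolved together when an experience is assembled.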
class ExperienceCompositionRefs:
def __init__(self, sequences_attention_mask_action_mask_ref: ray.ObjectRef, action_log_probs_ref: ray.ObjectRef,
base_action_log_probs_ref: ray.ObjectRef, value_ref: ray.ObjectRef, r_ref: ray.ObjectRef) -> None:
self.sequences_attention_mask_action_mask_ref = sequences_attention_mask_action_mask_ref
self.action_log_probs_ref = action_log_probs_ref
self.base_action_log_probs_ref = base_action_log_probs_ref
self.value_ref = value_ref
self.r_ref = r_ref
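
# Resolves the ObjectRefs above into a coati Experience: the reward is the score
# shaped by a KL penalty against the initial model, and advantage = reward - value.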
class ExperienceMaker:
def __init__(self, kl_coef) -> None:
self.kl_coef = kl_coef
@torch.no_grad()
def make_experience(self, experiment_computation_refs: ExperienceCompositionRefs):
sequences, attention_mask, action_mask = ray.get(
experiment_computation_refs.sequences_attention_mask_action_mask_ref)
action_log_probs = ray.get(experiment_computation_refs.action_log_probs_ref)
base_action_log_probs = ray.get(experiment_computation_refs.base_action_log_probs_ref)
r = ray.get(experiment_computation_refs.r_ref)
reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
value = ray.get(experiment_computation_refs.value_ref)
advantage = reward - value
if advantage.ndim == 1:
advantage = advantage.unsqueeze(-1)
experience = Experience(sequences, action_log_probs, value, reward, advantage, attention_mask, action_mask)
return experience
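
# Base Ray actor that prepares the torch.distributed environment (MASTER_ADDR,
# MASTER_PORT, WORLD_SIZE, RANK, LOCAL_RANK) required by the training strategies.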
class DistributedTorchRayActor:
def __init__(self, world_size, rank, local_rank, master_addr, master_port):
logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
level=logging.INFO,
datefmt='%Y-%m-%d %H:%M:%S')
self._model = None
self._world_size = world_size
self._rank = rank
self._local_rank = local_rank
self._master_addr = master_addr if master_addr else self._get_current_node_ip()
self._master_port = master_port if master_port else self._get_free_port()
os.environ["MASTER_ADDR"] = self._master_addr
os.environ["MASTER_PORT"] = str(self._master_port)
os.environ["WORLD_SIZE"] = str(self._world_size)
os.environ["RANK"] = str(self._rank)
os.environ["LOCAL_RANK"] = str(self._local_rank)
@staticmethod
def _get_current_node_ip():
return ray._private.services.get_node_ip_address()
@staticmethod
def _get_free_port():
with socket.socket() as sock:
sock.bind(('', 0))
return sock.getsockname()[1]
def get_master_addr_port(self):
return self._master_addr, self._master_port
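
# Shared PPO role logic: strategy selection, optimizer setup, loading a pretrained
# model and turning an ExperienceCompositionRefs bundle into an Experience.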
class BasePPORole(DistributedTorchRayActor):
def add_experience_maker(self, kl_coef: float = 0.1):
self._experience_maker = ExperienceMaker(kl_coef)
def make_experience(self, experience_computation_ref: ExperienceCompositionRefs):
return self._experience_maker.make_experience(experience_computation_ref)
def _init_strategy(self, strategy: str):
# configure strategy
if strategy == 'naive':
self._strategy = NaiveStrategy()
elif strategy == 'ddp':
self._strategy = DDPStrategy()
elif strategy == 'colossalai_gemini':
self._strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
elif strategy == 'colossalai_zero2':
self._strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
else:
raise ValueError(f'Unsupported strategy "{strategy}"')
def _init_optimizer(self):
if isinstance(self._strategy, ColossalAIStrategy):
self._optimizer = HybridAdam(self._model.parameters(), lr=5e-6)
else:
self._optimizer = Adam(self._model.parameters(), lr=5e-6)
def _prepare_model_with_strategy(self, has_optimizer: bool):
if has_optimizer:
self._init_optimizer()
(self._model, self._optimizer) = self._strategy.prepare((self._model, self._optimizer))
else:
self._model = self._strategy.prepare(self._model)
def _load_model_from_pretrained(self, model_class: Type[LoRAModule], pretrain: str):
raise NotImplementedError()
def init_model_from_pretrained(self,
strategy: str,
model_class: Type[LoRAModule],
pretrain: str,
has_optimizer=False):
self._init_strategy(strategy)
self._load_model_from_pretrained(model_class, pretrain)
self._prepare_model_with_strategy(has_optimizer)
def eval(self):
self._model.eval()
class TrainablePPORole(BasePPORole):
def _load_model_from_pretrained(self, model_class, pretrain):
with self._strategy.model_init_context():
self._model = model_class(pretrain).to(torch.cuda.current_device())
def _train(self):
self._model.train()
def _training_step(self, experience: Experience):
raise NotImplementedError()
def learn_on_experiences(self, experience_refs):
experiences = ray.get(experience_refs)
device = torch.cuda.current_device()
self._train()
for exp in experiences:
exp.to_device(device)
self._training_step(exp)
self.eval()
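
# Policy (actor) role: samples prompts, generates sequences, computes action log
# probs and is updated with the PPO clipped policy loss.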
@ray.remote(num_gpus=1)
class RayPPOActor(TrainablePPORole):
def set_loss_function(self, eps_clip: float):
self._actor_loss_fn = PolicyLoss(eps_clip)
def load_tokenizer_from_pretrained(self, model_type: str, pretrained):
if model_type == 'gpt2':
self._model_tokenizer = GPT2Tokenizer.from_pretrained(pretrained)
self._model_tokenizer.pad_token = self._model_tokenizer.eos_token
elif model_type == 'bloom':
self._model_tokenizer = BloomTokenizerFast.from_pretrained(pretrained)
self._model_tokenizer.pad_token = self._model_tokenizer.eos_token
elif model_type == 'opt':
self._model_tokenizer = AutoTokenizer.from_pretrained(pretrained)
else:
raise ValueError(f'Unsupported model "{model_type}"')
# Set tokenize function for sequence generation
def _text_input_tokenize_fn(texts):
batch = self._model_tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
return {k: v.cuda() for k, v in batch.items()}
self._sample_tokenize_function = _text_input_tokenize_fn
def setup_generate_kwargs(self, generate_kwargs: dict):
from coati.trainer.ppo import _set_default_generate_kwargs
self._generate_kwargs = _set_default_generate_kwargs(self._strategy, generate_kwargs, self._model)
self._generate_kwargs['pad_token_id'] = self._model_tokenizer.pad_token_id
self._generate_kwargs['eos_token_id'] = self._model_tokenizer.eos_token_id
def load_csv_prompt_file_from_url_to_sampler(self, prompt_url):
import pandas as pd
prompts = pd.read_csv(prompt_url)['prompt']
self._sampler = self._strategy.setup_sampler(prompts)
def _generate(self, input_ids, **generate_kwargs):
return self._model.generate(input_ids, return_action_mask=True, **generate_kwargs)
def sample_prompts_and_make_sequence(self, experience_batch_size):
sampled_prompts = self._sampler.sample(experience_batch_size)
input_ids = self._sample_tokenize_function(sampled_prompts)
if isinstance(input_ids, dict):
return self._generate(**input_ids, **self._generate_kwargs)
else:
return self._generate(input_ids, **self._generate_kwargs)
@torch.no_grad()
def calculate_action_log_probs(self, sequence_attention_action_mask):
sequences, attention_mask, action_mask = sequence_attention_action_mask
return self._model.forward(sequences, action_mask.size(1), attention_mask)
def _training_step(self, experience):
num_actions = experience.action_mask.size(1)
action_log_probs = self._model(experience.sequences, num_actions, attention_mask=experience.attention_mask)
actor_loss = self._actor_loss_fn(action_log_probs,
experience.action_log_probs,
experience.advantages,
action_mask=experience.action_mask)
self._strategy.backward(actor_loss, self._model, self._optimizer)
self._strategy.optimizer_step(self._optimizer)
self._optimizer.zero_grad()
logging.info("actor_loss: {}".format(actor_loss))
def save_checkpoint(self, save_path, should_save_optimizer: bool):
if self._rank == 0:
# save model checkpoint only on rank 0
self._strategy.save_model(self._model, save_path, only_rank0=True)
# save optimizer checkpoint on all ranks
if should_save_optimizer:
self._strategy.save_optimizer(self._optimizer,
'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
only_rank0=False)
def generate_answer(self, prompt, max_length=30, num_return_sequences=5):
encoded_input = self._model_tokenizer(prompt, return_tensors='pt')
input_ids = {k: v.cuda() for k, v in encoded_input.items()}
sequence, _ = self._model.generate(**input_ids,
max_length=max_length,
return_action_mask=False,
num_return_sequences=num_return_sequences)
token_list = list(sequence.data[0])
output = " ".join([self._model_tokenizer.decode(token) for token in token_list])
return output
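
# Critic role: estimates values for generated sequences and is updated with the
# clipped value loss.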
@ray.remote(num_gpus=1)
class RayPPOCritic(TrainablePPORole):
def set_loss_function(self, value_clip: float):
self._critic_loss_fn = ValueLoss(value_clip)
def _training_step(self, experience):
values = self._model(experience.sequences,
action_mask=experience.action_mask,
attention_mask=experience.attention_mask)
critic_loss = self._critic_loss_fn(values,
experience.values,
experience.reward,
action_mask=experience.action_mask)
self._strategy.backward(critic_loss, self._model, self._optimizer)
self._strategy.optimizer_step(self._optimizer)
self._optimizer.zero_grad()
logging.info("critic_loss: {}".format(critic_loss))
@torch.no_grad()
def calculate_value(self, sequence_attention_action_mask):
sequences, attention_mask, action_mask = sequence_attention_action_mask
return self._model(sequences, action_mask, attention_mask)
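
# Reward model role: wraps the pretrained critic backbone in a RewardModel and
# scores generated sequences; it is not updated during training.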
@ray.remote(num_gpus=1)
class RayPPORewardModel(BasePPORole):
def _load_model_from_pretrained(self, model_class, pretrain):
with self._strategy.model_init_context():
critic = model_class(pretrained=pretrain).to(torch.cuda.current_device())
self._model = RewardModel(deepcopy(critic.model),
deepcopy(critic.value_head)).to(torch.cuda.current_device())
@torch.no_grad()
def calculate_r(self, sequence_attention_action_mask):
sequences, attention_mask, _ = sequence_attention_action_mask
return self._model(sequences, attention_mask)
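
# Initial (reference) model role: a copy of the pretrained actor kept fixed during
# training, used to compute the baseline log probs for the KL penalty.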
@ray.remote(num_gpus=1)
class RayPPOInitialModel(BasePPORole):
def _load_model_from_pretrained(self, model_class, pretrain):
with self._strategy.model_init_context():
self._model = model_class(pretrain).to(torch.cuda.current_device())
@torch.no_grad()
def calculate_base_action_log_probs(self, sequence_attention_action_mask):
sequences, attention_mask, action_mask = sequence_attention_action_mask
return self._model(sequences, action_mask.size(1), attention_mask)
class PPORayActorGroup:
"""
A group of ray actors
Functions start with 'async' should return list of object refs
"""
def __init__(self, num_nodes, num_gpus_per_node, ray_actor_type: Type[BasePPORole]) -> None:
self._num_nodes = num_nodes
self._num_gpus_per_node = num_gpus_per_node
self.ray_actor_type = ray_actor_type
self._initiate_actors()
def _initiate_actors(self):
world_size = self._num_nodes * self._num_gpus_per_node
# Use placement group to lock resources for models of same type
pg = None
if self._num_gpus_per_node > 1:
bundles = [{"GPU": self._num_gpus_per_node, "CPU": self._num_gpus_per_node} for _ in range(self._num_nodes)]
pg = placement_group(bundles, strategy="STRICT_SPREAD")
ray.get(pg.ready())
if pg:
master_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg, placement_group_bundle_index=0)).remote(world_size, 0, 0, None, None)
else:
master_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, 0, 0, None, None)
self._actor_handlers = [master_actor]
# Create worker actors
if world_size > 1:
master_addr, master_port = ray.get(master_actor.get_master_addr_port.remote())
for rank in range(1, world_size):
local_rank = rank % self._num_gpus_per_node
if pg:
worker_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node)).remote(
world_size, rank, local_rank, master_addr, master_port)
else:
worker_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, rank, local_rank,
master_addr, master_port)
self._actor_handlers.append(worker_actor)
def async_init_model_from_pretrained(self, strategy: str, model_class: Type[LoRAModule], pretrain: str,
has_optimizer: bool):
return [
actor.init_model_from_pretrained.remote(strategy, model_class, pretrain, has_optimizer)
for actor in self._actor_handlers
]
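
# Groups whose models are updated during training; experience refs are split
# round-robin across the group's Ray actors for the learning step.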
class TrainableModelRayActorGroup(PPORayActorGroup):
def async_learn_on_experiences(self, experience_refs):
num_actors = len(self._actor_handlers)
learn_result_refs = []
for i in range(num_actors):
exp_refs_batch = experience_refs[i::num_actors]
learn_result_refs.append(self._actor_handlers[i].learn_on_experiences.remote(exp_refs_batch))
return learn_result_refs
class PPOActorRayActorGroup(TrainableModelRayActorGroup):
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPOActor)
def async_prepare_for_sequence_generation(self, model: str, pretrain: str, generation_kwargs: dict):
refs = []
for actor in self._actor_handlers:
refs.append(actor.load_tokenizer_from_pretrained.remote(model, pretrain))
refs.append(actor.setup_generate_kwargs.remote(generation_kwargs))
return refs
def load_csv_prompt_file_from_url_to_sampler(self, csv_url):
ray.get([actor.load_csv_prompt_file_from_url_to_sampler.remote(csv_url) for actor in self._actor_handlers])
def async_sample_prompts_and_make_sequence(self, experience_batch_size):
return [actor.sample_prompts_and_make_sequence.remote(experience_batch_size) for actor in self._actor_handlers]
def async_calculate_action_log_probs(self, sequences_attention_mask_action_mask_refs):
num_actors = len(self._actor_handlers)
action_log_probs_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_action_log_probs.remote(
sequences_attention_mask_action_mask_refs[i])
action_log_probs_refs.append(action_log_probs_ref)
return action_log_probs_refs
def set_loss_function(self, eps_clip: float = 0.2):
ray.get([actor.set_loss_function.remote(eps_clip) for actor in self._actor_handlers])
def save_checkpoint(self, save_path, should_save_optimizer):
ray.get([actor.save_checkpoint.remote(save_path, should_save_optimizer) for actor in self._actor_handlers])
class PPOCriticRayActorGroup(TrainableModelRayActorGroup):
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPOCritic)
def async_calculate_value(self, sequences_attention_mask_action_mask_refs):
num_actors = len(self._actor_handlers)
value_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
value_ref = self._actor_handlers[i % num_actors].calculate_value.remote(
sequences_attention_mask_action_mask_refs[i])
value_refs.append(value_ref)
return value_refs
def set_loss_function(self, value_clip: float = 0.4):
ray.get([actor.set_loss_function.remote(value_clip) for actor in self._actor_handlers])
class PPOInitialRayActorGroup(PPORayActorGroup):
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPOInitialModel)
def async_calculate_base_action_log_probs(self, sequences_attention_mask_action_mask_refs):
num_actors = len(self._actor_handlers)
base_action_log_probs_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
base_action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_base_action_log_probs.remote(
sequences_attention_mask_action_mask_refs[i])
base_action_log_probs_refs.append(base_action_log_probs_ref)
return base_action_log_probs_refs
class PPORewardRayActorGroup(PPORayActorGroup):
def __init__(self, num_nodes, num_gpus_per_node) -> None:
super().__init__(num_nodes, num_gpus_per_node, RayPPORewardModel)
def async_calculate_r(self, sequences_attention_mask_action_mask_refs):
num_actors = len(self._actor_handlers)
r_refs = []
for i in range(len(sequences_attention_mask_action_mask_refs)):
r_ref = self._actor_handlers[i % num_actors].calculate_r.remote(
sequences_attention_mask_action_mask_refs[i])
r_refs.append(r_ref)
return r_refs
def main(args):
logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
level=logging.INFO,
datefmt='%Y-%m-%d %H:%M:%S')
if args.model == 'gpt2':
actor_model_class, critic_model_class = GPTActor, GPTCritic
elif args.model == 'bloom':
actor_model_class, critic_model_class = BLOOMActor, BLOOMCritic
elif args.model == 'opt':
actor_model_class, critic_model_class = OPTActor, OPTCritic
else:
raise ValueError(f'Unsupported model "{args.model}"')
logging.info("Start creating actors")
# Initialize 4 models (actor, critic, initial_model and reward_model)
actor_group = PPOActorRayActorGroup(num_nodes=args.num_actor_nodes, num_gpus_per_node=args.num_gpus_per_node)
critic_group = PPOCriticRayActorGroup(num_nodes=args.num_critic_nodes, num_gpus_per_node=args.num_gpus_per_node)
initial_group = PPOInitialRayActorGroup(num_nodes=args.num_initial_nodes, num_gpus_per_node=args.num_gpus_per_node)
reward_group = PPORewardRayActorGroup(num_nodes=args.num_reward_nodes, num_gpus_per_node=args.num_gpus_per_node)
logging.info("Actors created")
# Prepare model for training
generate_kwargs = {'max_length': 128, 'do_sample': True, 'temperature': 1.0, 'top_k': 50}
ray.get(
actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True) +
critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True) +
initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False) +
reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False) +
actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs))
logging.info("Models prepared for training")
# Prepare models for training
actor_group.load_csv_prompt_file_from_url_to_sampler(args.prompt_csv_url)
actor_group.set_loss_function()
critic_group.set_loss_function()
# Training parameter
num_episodes = args.num_episodes
max_timesteps = args.max_timesteps
update_timesteps = args.update_timesteps
experience_batch_size = args.experience_batch_size
# Start training
logging.info("Training start")
# Set all models to eval and add experience maker
all_ray_actors = actor_group._actor_handlers + critic_group._actor_handlers + \
initial_group._actor_handlers + reward_group._actor_handlers
num_ray_actors = len(all_ray_actors)
ray.get([ray_actor.eval.remote() for ray_actor in all_ray_actors])
ray.get([ray_actor.add_experience_maker.remote() for ray_actor in all_ray_actors])
# Used as a queue to coordinate experience making
experience_composition_refs = []
time = 0
for episode in range(num_episodes):
logging.info("episode {} started".format(episode))
for _ in range(max_timesteps):
time += 1
# Experience queueing stage
sequences_attention_mask_action_mask_refs = actor_group.async_sample_prompts_and_make_sequence(
experience_batch_size)
base_action_log_probs_refs = initial_group.async_calculate_base_action_log_probs(
sequences_attention_mask_action_mask_refs)
values_refs = critic_group.async_calculate_value(sequences_attention_mask_action_mask_refs)
r_refs = reward_group.async_calculate_r(sequences_attention_mask_action_mask_refs)
action_log_probs_refs = actor_group.async_calculate_action_log_probs(
sequences_attention_mask_action_mask_refs)
experience_composition_refs.extend([
ExperienceCompositionRefs(sequences_attention_mask_action_mask_refs[i], action_log_probs_refs[i],
base_action_log_probs_refs[i], values_refs[i], r_refs[i])
for i in range(len(sequences_attention_mask_action_mask_refs))
])
# Learning stage
if time % update_timesteps == 0:
experience_refs = []
# calculate experiences
for i, experience_composition_ref in enumerate(experience_composition_refs):
exp_composition_ref = experience_composition_ref
selected_ray_actor = all_ray_actors[i % num_ray_actors]
experience_refs.append(selected_ray_actor.make_experience.remote(exp_composition_ref))
# backward
ray.get(
actor_group.async_learn_on_experiences(experience_refs) +
critic_group.async_learn_on_experiences(experience_refs))
# clear refs queue
experience_composition_refs.clear()
logging.info("Training finished")
# Save checkpoint
actor_group.save_checkpoint(args.save_path, args.need_optim_ckpt)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--prompt_csv_url', type=str)
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
parser.add_argument('--pretrain', type=str, default='gpt2')
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
parser.add_argument('--num_episodes', type=int, default=10)
parser.add_argument('--max_timesteps', type=int, default=10)
parser.add_argument('--update_timesteps', type=int, default=10)
parser.add_argument('--train_batch_size', type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8)
parser.add_argument('--num_actor_nodes', type=int, help='num of nodes to use to host actor model', default=1)
parser.add_argument('--num_critic_nodes', type=int, help='num of nodes to use to host critic model', default=1)
parser.add_argument('--num_initial_nodes', type=int, help='num of nodes to use to host initial model', default=1)
parser.add_argument('--num_reward_nodes', type=int, help='num of nodes to use to host reward model', default=1)
parser.add_argument('--num_gpus_per_node', type=int, help='num of gpus on a ray node', default=1)
args = parser.parse_args()
ray.init()
main(args)