# mirror of https://github.com/hpcaitech/ColossalAI
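"""Distributed PPO (RLHF) training example for Coati on a Ray cluster.

Each of the four PPO roles (actor, critic, initial/reference model, reward
model) is hosted by a group of single-GPU Ray actors. Rollouts and the
per-role forward passes exchange data through Ray object refs; the actor and
critic groups then learn on the assembled experiences.
"""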
import argparse
import logging
import os
import socket
from copy import deepcopy
from typing import Type

import ray
import torch
from coati.experience_maker.base import Experience
from coati.models.base import RewardModel
from coati.models.bloom import BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTActor, GPTCritic
from coati.models.lora import LoRAModule
from coati.models.loss import PolicyLoss, ValueLoss
from coati.models.opt import OPTActor, OPTCritic
from coati.models.utils import compute_reward
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam

class ExperienceCompositionRefs:
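    """Bundle of Ray object refs needed to assemble one batch of experience.

    Holds refs to the generated sequences (with attention/action masks), the
    actor's action log-probs, the initial model's log-probs, the critic's
    value estimate and the reward model's score.
    """
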
    def __init__(self, sequences_attention_mask_action_mask_ref: ray.ObjectRef, action_log_probs_ref: ray.ObjectRef,
                 base_action_log_probs_ref: ray.ObjectRef, value_ref: ray.ObjectRef, r_ref: ray.ObjectRef) -> None:
        self.sequences_attention_mask_action_mask_ref = sequences_attention_mask_action_mask_ref
        self.action_log_probs_ref = action_log_probs_ref
        self.base_action_log_probs_ref = base_action_log_probs_ref
        self.value_ref = value_ref
        self.r_ref = r_ref


class ExperienceMaker:
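    """Assembles `Experience` objects from the refs produced by the model roles.

    The reward combines the reward-model score with a KL term via
    `compute_reward` (weighted by `kl_coef`), and the advantage is simply
    `reward - value`; no GAE is used here.
    """
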
    def __init__(self, kl_coef) -> None:
        self.kl_coef = kl_coef

    @torch.no_grad()
    def make_experience(self, experience_computation_refs: ExperienceCompositionRefs):
        sequences, attention_mask, action_mask = ray.get(
            experience_computation_refs.sequences_attention_mask_action_mask_ref)
        action_log_probs = ray.get(experience_computation_refs.action_log_probs_ref)
        base_action_log_probs = ray.get(experience_computation_refs.base_action_log_probs_ref)
        r = ray.get(experience_computation_refs.r_ref)
        reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
        value = ray.get(experience_computation_refs.value_ref)
        advantage = reward - value
        if advantage.ndim == 1:
            advantage = advantage.unsqueeze(-1)
        experience = Experience(sequences, action_log_probs, value, reward, advantage, attention_mask, action_mask)
        return experience


class DistributedTorchRayActor:
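    """Base class for a single-GPU Ray actor that joins a torch.distributed group.

    Rank 0 picks a master address/port, and every actor exports the standard
    MASTER_ADDR/MASTER_PORT/WORLD_SIZE/RANK/LOCAL_RANK environment variables so
    that the distributed process group can be initialised from the environment.
    """
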
    def __init__(self, world_size, rank, local_rank, master_addr, master_port):
        logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
                            level=logging.INFO,
                            datefmt='%Y-%m-%d %H:%M:%S')
        self._model = None
        self._world_size = world_size
        self._rank = rank
        self._local_rank = local_rank
        self._master_addr = master_addr if master_addr else self._get_current_node_ip()
        self._master_port = master_port if master_port else self._get_free_port()
        os.environ["MASTER_ADDR"] = self._master_addr
        os.environ["MASTER_PORT"] = str(self._master_port)
        os.environ["WORLD_SIZE"] = str(self._world_size)
        os.environ["RANK"] = str(self._rank)
        os.environ["LOCAL_RANK"] = str(self._local_rank)

    @staticmethod
    def _get_current_node_ip():
        return ray._private.services.get_node_ip_address()

    @staticmethod
    def _get_free_port():
        with socket.socket() as sock:
            sock.bind(('', 0))
            return sock.getsockname()[1]

    def get_master_addr_port(self):
        return self._master_addr, self._master_port


class BasePPORole(DistributedTorchRayActor):
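    """Behaviour shared by all four PPO roles (actor, critic, initial and reward model).

    Handles strategy and optimizer setup, loading the underlying model from a
    pretrained checkpoint, and turning `ExperienceCompositionRefs` into
    `Experience` objects via the attached `ExperienceMaker`.
    """
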
    def add_experience_maker(self, kl_coef: float = 0.1):
        self._experience_maker = ExperienceMaker(kl_coef)

    def make_experience(self, experience_computation_ref: ExperienceCompositionRefs):
        return self._experience_maker.make_experience(experience_computation_ref)

    def _init_strategy(self, strategy: str):
        # configure strategy
        if strategy == 'naive':
            self._strategy = NaiveStrategy()
        elif strategy == 'ddp':
            self._strategy = DDPStrategy()
        elif strategy == 'colossalai_gemini':
            self._strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
        elif strategy == 'colossalai_zero2':
            self._strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
        else:
            raise ValueError(f'Unsupported strategy "{strategy}"')

    def _init_optimizer(self):
        if isinstance(self._strategy, ColossalAIStrategy):
            self._optimizer = HybridAdam(self._model.parameters(), lr=5e-6)
        else:
            self._optimizer = Adam(self._model.parameters(), lr=5e-6)

    def _prepare_model_with_strategy(self, has_optimizer: bool):
        if has_optimizer:
            self._init_optimizer()
            (self._model, self._optimizer) = self._strategy.prepare((self._model, self._optimizer))
        else:
            self._model = self._strategy.prepare(self._model)

    def _load_model_from_pretrained(self, model_class: Type[LoRAModule], pretrain: str):
        raise NotImplementedError()

    def init_model_from_pretrained(self,
                                   strategy: str,
                                   model_class: Type[LoRAModule],
                                   pretrain: str,
                                   has_optimizer=False):
        self._init_strategy(strategy)
        self._load_model_from_pretrained(model_class, pretrain)
        self._prepare_model_with_strategy(has_optimizer)

    def eval(self):
        self._model.eval()


class TrainablePPORole(BasePPORole):
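    """A PPO role that is updated during training (the actor and the critic).

    Subclasses implement `_training_step`; `learn_on_experiences` pulls a batch
    of `Experience` refs, runs the training steps on the local GPU, then
    switches the model back to eval mode for the next rollout phase.
    """
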
    def _load_model_from_pretrained(self, model_class, pretrain):
        with self._strategy.model_init_context():
            self._model = model_class(pretrain).to(torch.cuda.current_device())

    def _train(self):
        self._model.train()

    def _training_step(self, experience: Experience):
        raise NotImplementedError()

    def learn_on_experiences(self, experience_refs):
        experiences = ray.get(experience_refs)
        device = torch.cuda.current_device()
        self._train()
        for exp in experiences:
            exp.to_device(device)
            self._training_step(exp)
        self.eval()


@ray.remote(num_gpus=1)
class RayPPOActor(TrainablePPORole):
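    """The PPO policy (actor) role.

    Owns the tokenizer, prompt sampler and generation kwargs; generates
    sequences from sampled prompts, computes action log-probs, and is updated
    with `PolicyLoss`.
    """
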
    def set_loss_function(self, eps_clip: float):
        self._actor_loss_fn = PolicyLoss(eps_clip)

    def load_tokenizer_from_pretrained(self, model_type: str, pretrained):
        if model_type == 'gpt2':
            self._model_tokenizer = GPT2Tokenizer.from_pretrained(pretrained)
            self._model_tokenizer.pad_token = self._model_tokenizer.eos_token
        elif model_type == 'bloom':
            self._model_tokenizer = BloomTokenizerFast.from_pretrained(pretrained)
            self._model_tokenizer.pad_token = self._model_tokenizer.eos_token
        elif model_type == 'opt':
            self._model_tokenizer = AutoTokenizer.from_pretrained(pretrained)
        else:
            raise ValueError(f'Unsupported model "{model_type}"')

        # Set tokenize function for sequence generation
        def _text_input_tokenize_fn(texts):
            batch = self._model_tokenizer(texts, return_tensors='pt', max_length=96, padding=True, truncation=True)
            return {k: v.cuda() for k, v in batch.items()}

        self._sample_tokenize_function = _text_input_tokenize_fn

    def setup_generate_kwargs(self, generate_kwargs: dict):
        from coati.trainer.ppo import _set_default_generate_kwargs
        self._generate_kwargs = _set_default_generate_kwargs(self._strategy, generate_kwargs, self._model)
        self._generate_kwargs['pad_token_id'] = self._model_tokenizer.pad_token_id
        self._generate_kwargs['eos_token_id'] = self._model_tokenizer.eos_token_id

    def load_csv_prompt_file_from_url_to_sampler(self, prompt_url):
        import pandas as pd
        prompts = pd.read_csv(prompt_url)['prompt']
        self._sampler = self._strategy.setup_sampler(prompts)

    def _generate(self, input_ids, **generate_kwargs):
        return self._model.generate(input_ids, return_action_mask=True, **generate_kwargs)

    def sample_prompts_and_make_sequence(self, experience_batch_size):
        sampled_prompts = self._sampler.sample(experience_batch_size)
        input_ids = self._sample_tokenize_function(sampled_prompts)
        if isinstance(input_ids, dict):
            return self._generate(**input_ids, **self._generate_kwargs)
        else:
            return self._generate(input_ids, **self._generate_kwargs)

    @torch.no_grad()
    def calculate_action_log_probs(self, sequence_attention_action_mask):
        sequences, attention_mask, action_mask = sequence_attention_action_mask
        return self._model.forward(sequences, action_mask.size(1), attention_mask)

    def _training_step(self, experience):
        num_actions = experience.action_mask.size(1)
        action_log_probs = self._model(experience.sequences, num_actions, attention_mask=experience.attention_mask)
        actor_loss = self._actor_loss_fn(action_log_probs,
                                         experience.action_log_probs,
                                         experience.advantages,
                                         action_mask=experience.action_mask)
        self._strategy.backward(actor_loss, self._model, self._optimizer)
        self._strategy.optimizer_step(self._optimizer)
        self._optimizer.zero_grad()
        logging.info("actor_loss: {}".format(actor_loss))

    def save_checkpoint(self, save_path, should_save_optimizer: bool):
        if self._rank == 0:
            # save model checkpoint only on rank 0
            self._strategy.save_model(self._model, save_path, only_rank0=True)
        # save optimizer checkpoint on all ranks
        if should_save_optimizer:
            self._strategy.save_optimizer(self._optimizer,
                                          'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
                                          only_rank0=False)

    def generate_answer(self, prompt, max_length=30, num_return_sequences=5):
        encoded_input = self._model_tokenizer(prompt, return_tensors='pt')
        input_ids = {k: v.cuda() for k, v in encoded_input.items()}
        sequence, _ = self._model.generate(**input_ids,
                                           max_length=max_length,
                                           return_action_mask=False,
                                           num_return_sequences=num_return_sequences)
        token_list = list(sequence.data[0])
        output = " ".join([self._model_tokenizer.decode(token) for token in token_list])
        return output


@ray.remote(num_gpus=1)
class RayPPOCritic(TrainablePPORole):
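    """The PPO value (critic) role, updated with `ValueLoss`."""
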
    def set_loss_function(self, value_clip: float):
        self._critic_loss_fn = ValueLoss(value_clip)

    def _training_step(self, experience):
        values = self._model(experience.sequences,
                             action_mask=experience.action_mask,
                             attention_mask=experience.attention_mask)
        critic_loss = self._critic_loss_fn(values,
                                           experience.values,
                                           experience.reward,
                                           action_mask=experience.action_mask)
        self._strategy.backward(critic_loss, self._model, self._optimizer)
        self._strategy.optimizer_step(self._optimizer)
        self._optimizer.zero_grad()
        logging.info("critic_loss: {}".format(critic_loss))

    @torch.no_grad()
    def calculate_value(self, sequence_attention_action_mask):
        sequences, attention_mask, action_mask = sequence_attention_action_mask
        return self._model(sequences, action_mask, attention_mask)


@ray.remote(num_gpus=1)
class RayPPORewardModel(BasePPORole):
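    """Reward-model role: scores generated sequences during experience making (not trained here)."""
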
    def _load_model_from_pretrained(self, model_class, pretrain):
        with self._strategy.model_init_context():
            critic = model_class(pretrained=pretrain).to(torch.cuda.current_device())
            self._model = RewardModel(deepcopy(critic.model),
                                      deepcopy(critic.value_head)).to(torch.cuda.current_device())

    @torch.no_grad()
    def calculate_r(self, sequence_attention_action_mask):
        sequences, attention_mask, _ = sequence_attention_action_mask
        return self._model(sequences, attention_mask)


@ray.remote(num_gpus=1)
class RayPPOInitialModel(BasePPORole):
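    """Copy of the initial policy; provides the reference log-probs used for the KL term in the reward."""
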
    def _load_model_from_pretrained(self, model_class, pretrain):
        with self._strategy.model_init_context():
            self._model = model_class(pretrain).to(torch.cuda.current_device())

    @torch.no_grad()
    def calculate_base_action_log_probs(self, sequence_attention_action_mask):
        sequences, attention_mask, action_mask = sequence_attention_action_mask
        return self._model(sequences, action_mask.size(1), attention_mask)


class PPORayActorGroup:
    """A group of Ray actors that together host one model role.

    Methods whose names start with 'async_' return a list of Ray object refs
    instead of blocking on the results.
    """

    def __init__(self, num_nodes, num_gpus_per_node, ray_actor_type: Type[BasePPORole]) -> None:
        self._num_nodes = num_nodes
        self._num_gpus_per_node = num_gpus_per_node
        self.ray_actor_type = ray_actor_type
        self._initiate_actors()

    def _initiate_actors(self):
        world_size = self._num_nodes * self._num_gpus_per_node
        # Use placement group to lock resources for models of same type
        pg = None
        if self._num_gpus_per_node > 1:
            bundles = [{"GPU": self._num_gpus_per_node, "CPU": self._num_gpus_per_node} for _ in range(self._num_nodes)]
            pg = placement_group(bundles, strategy="STRICT_SPREAD")
            ray.get(pg.ready())
        if pg:
            master_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy(
                placement_group=pg, placement_group_bundle_index=0)).remote(world_size, 0, 0, None, None)
        else:
            master_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, 0, 0, None, None)
        self._actor_handlers = [master_actor]

        # Create worker actors
        if world_size > 1:
            master_addr, master_port = ray.get(master_actor.get_master_addr_port.remote())
            for rank in range(1, world_size):
                local_rank = rank % self._num_gpus_per_node
                if pg:
                    worker_actor = self.ray_actor_type.options(scheduling_strategy=PlacementGroupSchedulingStrategy(
                        placement_group=pg, placement_group_bundle_index=rank // self._num_gpus_per_node)).remote(
                            world_size, rank, local_rank, master_addr, master_port)
                else:
                    worker_actor = self.ray_actor_type.options(num_gpus=1).remote(world_size, rank, local_rank,
                                                                                  master_addr, master_port)
                self._actor_handlers.append(worker_actor)

    def async_init_model_from_pretrained(self, strategy: str, model_class: Type[LoRAModule], pretrain: str,
                                         has_optimizer: bool):
        return [
            actor.init_model_from_pretrained.remote(strategy, model_class, pretrain, has_optimizer)
            for actor in self._actor_handlers
        ]


class TrainableModelRayActorGroup(PPORayActorGroup):
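    """Actor group for trainable roles; scatters experience refs round-robin across its Ray actors for learning."""
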
    def async_learn_on_experiences(self, experience_refs):
        num_actors = len(self._actor_handlers)
        learn_result_refs = []
        for i in range(num_actors):
            exp_refs_batch = experience_refs[i::num_actors]
            learn_result_refs.append(self._actor_handlers[i].learn_on_experiences.remote(exp_refs_batch))
        return learn_result_refs


class PPOActorRayActorGroup(TrainableModelRayActorGroup):
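    """Group of `RayPPOActor` actors: handles prompt sampling, sequence generation and policy updates."""
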
    def __init__(self, num_nodes, num_gpus_per_node) -> None:
        super().__init__(num_nodes, num_gpus_per_node, RayPPOActor)

    def async_prepare_for_sequence_generation(self, model: str, pretrain: str, generation_kwargs: dict):
        refs = []
        for actor in self._actor_handlers:
            refs.append(actor.load_tokenizer_from_pretrained.remote(model, pretrain))
            refs.append(actor.setup_generate_kwargs.remote(generation_kwargs))
        return refs

    def load_csv_prompt_file_from_url_to_sampler(self, csv_url):
        ray.get([actor.load_csv_prompt_file_from_url_to_sampler.remote(csv_url) for actor in self._actor_handlers])

    def async_sample_prompts_and_make_sequence(self, experience_batch_size):
        return [actor.sample_prompts_and_make_sequence.remote(experience_batch_size) for actor in self._actor_handlers]

    def async_calculate_action_log_probs(self, sequences_attention_mask_action_mask_refs):
        num_actors = len(self._actor_handlers)
        action_log_probs_refs = []
        for i in range(len(sequences_attention_mask_action_mask_refs)):
            action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_action_log_probs.remote(
                sequences_attention_mask_action_mask_refs[i])
            action_log_probs_refs.append(action_log_probs_ref)
        return action_log_probs_refs

    def set_loss_function(self, eps_clip: float = 0.2):
        ray.get([actor.set_loss_function.remote(eps_clip) for actor in self._actor_handlers])

    def save_checkpoint(self, save_path, should_save_optimizer):
        ray.get([actor.save_checkpoint.remote(save_path, should_save_optimizer) for actor in self._actor_handlers])


class PPOCriticRayActorGroup(TrainableModelRayActorGroup):

    def __init__(self, num_nodes, num_gpus_per_node) -> None:
        super().__init__(num_nodes, num_gpus_per_node, RayPPOCritic)

    def async_calculate_value(self, sequences_attention_mask_action_mask_refs):
        num_actors = len(self._actor_handlers)
        value_refs = []
        for i in range(len(sequences_attention_mask_action_mask_refs)):
            value_ref = self._actor_handlers[i % num_actors].calculate_value.remote(
                sequences_attention_mask_action_mask_refs[i])
            value_refs.append(value_ref)
        return value_refs

    def set_loss_function(self, value_clip: float = 0.4):
        ray.get([actor.set_loss_function.remote(value_clip) for actor in self._actor_handlers])


class PPOInitialRayActorGroup(PPORayActorGroup):

    def __init__(self, num_nodes, num_gpus_per_node) -> None:
        super().__init__(num_nodes, num_gpus_per_node, RayPPOInitialModel)

    def async_calculate_base_action_log_probs(self, sequences_attention_mask_action_mask_refs):
        num_actors = len(self._actor_handlers)
        base_action_log_probs_refs = []
        for i in range(len(sequences_attention_mask_action_mask_refs)):
            base_action_log_probs_ref = self._actor_handlers[i % num_actors].calculate_base_action_log_probs.remote(
                sequences_attention_mask_action_mask_refs[i])
            base_action_log_probs_refs.append(base_action_log_probs_ref)
        return base_action_log_probs_refs


class PPORewardRayActorGroup(PPORayActorGroup):

    def __init__(self, num_nodes, num_gpus_per_node) -> None:
        super().__init__(num_nodes, num_gpus_per_node, RayPPORewardModel)

    def async_calculate_r(self, sequences_attention_mask_action_mask_refs):
        num_actors = len(self._actor_handlers)
        r_refs = []
        for i in range(len(sequences_attention_mask_action_mask_refs)):
            r_ref = self._actor_handlers[i % num_actors].calculate_r.remote(
                sequences_attention_mask_action_mask_refs[i])
            r_refs.append(r_ref)
        return r_refs


def main(args):
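    """End-to-end PPO training loop over the Ray actor groups.

    Rollout: the actor group samples prompts and generates sequences, then the
    initial, critic and reward groups score them asynchronously. Every
    `update_timesteps` rollouts, the queued refs are turned into experiences
    (distributed round-robin over all Ray actors) and the actor and critic
    groups learn on them.
    """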
    logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s',
                        level=logging.INFO,
                        datefmt='%Y-%m-%d %H:%M:%S')
    if args.model == 'gpt2':
        actor_model_class, critic_model_class = GPTActor, GPTCritic
    elif args.model == 'bloom':
        actor_model_class, critic_model_class = BLOOMActor, BLOOMCritic
    elif args.model == 'opt':
        actor_model_class, critic_model_class = OPTActor, OPTCritic
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    logging.info("Start creating actors")
    # Initialize 4 models (actor, critic, initial_model and reward_model)
    actor_group = PPOActorRayActorGroup(num_nodes=args.num_actor_nodes, num_gpus_per_node=args.num_gpus_per_node)
    critic_group = PPOCriticRayActorGroup(num_nodes=args.num_critic_nodes, num_gpus_per_node=args.num_gpus_per_node)
    initial_group = PPOInitialRayActorGroup(num_nodes=args.num_initial_nodes, num_gpus_per_node=args.num_gpus_per_node)
    reward_group = PPORewardRayActorGroup(num_nodes=args.num_reward_nodes, num_gpus_per_node=args.num_gpus_per_node)
    logging.info("Actors created")

    # Prepare model for training
    generate_kwargs = {'max_length': 128, 'do_sample': True, 'temperature': 1.0, 'top_k': 50}
    ray.get(
        actor_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, True) +
        critic_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, True) +
        initial_group.async_init_model_from_pretrained(args.strategy, actor_model_class, args.pretrain, False) +
        reward_group.async_init_model_from_pretrained(args.strategy, critic_model_class, args.pretrain, False) +
        actor_group.async_prepare_for_sequence_generation(args.model, args.pretrain, generate_kwargs))
    logging.info("Models prepared for training")

    # Prepare models for training
    actor_group.load_csv_prompt_file_from_url_to_sampler(args.prompt_csv_url)
    actor_group.set_loss_function()
    critic_group.set_loss_function()
    # Training parameter
    num_episodes = args.num_episodes
    max_timesteps = args.max_timesteps
    update_timesteps = args.update_timesteps
    experience_batch_size = args.experience_batch_size
    # Start training
    logging.info("Training start")
    # Set all models to eval and add experience maker
    all_ray_actors = actor_group._actor_handlers + critic_group._actor_handlers + \
        initial_group._actor_handlers + reward_group._actor_handlers
    num_ray_actors = len(all_ray_actors)
    ray.get([ray_actor.eval.remote() for ray_actor in all_ray_actors])
    ray.get([ray_actor.add_experience_maker.remote() for ray_actor in all_ray_actors])
    # Used as a queue to coordinate experience making
    experience_composition_refs = []
    time = 0
    for episode in range(num_episodes):
        logging.info("episode {} started".format(episode))
        for _ in range(max_timesteps):
            time += 1
            # Experience queueing stage
            sequences_attention_mask_action_mask_refs = actor_group.async_sample_prompts_and_make_sequence(
                experience_batch_size)
            base_action_log_probs_refs = initial_group.async_calculate_base_action_log_probs(
                sequences_attention_mask_action_mask_refs)
            values_refs = critic_group.async_calculate_value(sequences_attention_mask_action_mask_refs)
            r_refs = reward_group.async_calculate_r(sequences_attention_mask_action_mask_refs)
            action_log_probs_refs = actor_group.async_calculate_action_log_probs(
                sequences_attention_mask_action_mask_refs)
            experience_composition_refs.extend([
                ExperienceCompositionRefs(sequences_attention_mask_action_mask_refs[i], action_log_probs_refs[i],
                                          base_action_log_probs_refs[i], values_refs[i], r_refs[i])
                for i in range(len(sequences_attention_mask_action_mask_refs))
            ])
            # Learning stage
            if time % update_timesteps == 0:
                experience_refs = []
                # calculate experiences
                for i, experience_composition_ref in enumerate(experience_composition_refs):
                    exp_composition_ref = experience_composition_ref
                    selected_ray_actor = all_ray_actors[i % num_ray_actors]
                    experience_refs.append(selected_ray_actor.make_experience.remote(exp_composition_ref))
                # backward
                ray.get(
                    actor_group.async_learn_on_experiences(experience_refs) +
                    critic_group.async_learn_on_experiences(experience_refs))
                # clear refs queue
                experience_composition_refs.clear()
    logging.info("Training finished")
    # Save checkpoint
    actor_group.save_checkpoint(args.save_path, args.need_optim_ckpt)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt_csv_url', type=str)
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
    parser.add_argument('--pretrain', type=str, default='gpt2')
    parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
    parser.add_argument('--num_episodes', type=int, default=10)
    parser.add_argument('--max_timesteps', type=int, default=10)
    parser.add_argument('--update_timesteps', type=int, default=10)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--experience_batch_size', type=int, default=8)
    parser.add_argument('--num_actor_nodes', type=int, help='num of nodes to use to host actor model', default=1)
    parser.add_argument('--num_critic_nodes', type=int, help='num of nodes to use to host critic model', default=1)
    parser.add_argument('--num_initial_nodes', type=int, help='num of nodes to use to host initial model', default=1)
    parser.add_argument('--num_reward_nodes', type=int, help='num of nodes to use to host reward model', default=1)
    parser.add_argument('--num_gpus_per_node', type=int, help='num of gpus on a ray node', default=1)
    args = parser.parse_args()
    ray.init()
    main(args)
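
# Example launch (illustrative values; the script name and the CSV location are placeholders):
#   python <this_script>.py \
#       --prompt_csv_url <URL or path to a CSV with a "prompt" column> \
#       --model gpt2 --pretrain gpt2 --strategy colossalai_zero2 \
#       --num_actor_nodes 1 --num_critic_nodes 1 --num_initial_nodes 1 --num_reward_nodes 1 \
#       --num_gpus_per_node 1
# Note: ray.init() above is called without an explicit address, so run this on a node of an
# already-started Ray cluster (e.g. via `ray start`) or let it bring up a local instance.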