ColossalAI/applications/Chat/coati/ray/src/experience_maker_holder.py

173 lines
7.6 KiB
Python

import torch
from typing import Any, Callable, Dict, List, Optional, Union
import ray
from ray.exceptions import GetTimeoutError
from torch import Tensor
import torch.nn as nn
from coati.models.base import Actor, Critic, RewardModel
from coati.trainer.strategies.sampler import DistributedSampler
from coati.trainer.strategies import Strategy
from coati.experience_maker import NaiveExperienceMaker, Experience, ExperienceMaker
from copy import deepcopy
from threading import Lock
import time
import os
from .utils import is_rank_0, get_strategy_from_args, set_dist_env
@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
'''
Args:
detached_trainer_name_list: str list to get ray actor handleskkk
strategy:
experience_batch_size: batch size of generated experience
kl_coef: the coefficient of kl divergence loss
'''
def __init__(self,
detached_trainer_name_list: List[str],
strategy: str,
env_info: Dict[str, str] = None,
experience_batch_size: int = 8,
kl_coef: float = 0.1,
**generate_kwargs):
# set environment variables
if env_info:
set_dist_env(env_info=env_info)
self.target_trainer_list = []
for name in detached_trainer_name_list:
self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
self.strategy_str = strategy
self.strategy = get_strategy_from_args(strategy)
self.experience_batch_size = experience_batch_size
self.kl_coef = kl_coef
self.generate_kwargs = generate_kwargs
# Need a trainer to give an actor and a critic via initialize_experience_maker(...)
actor, critic, reward_model, initial_model = None, None, None, None
self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef)
self._model_visit_lock = Lock()
self.fully_initialized = False
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
print('[maker] Waiting for INIT')
def _get_ready(self):
while not self.fully_initialized:
time.sleep(1.0)
def update_target_trainer_list(self, detached_trainer_name_list):
self.target_trainer_list = []
for name in detached_trainer_name_list:
self.target_trainer_list.append(ray.get_actor(name))
# copy from ../trainer/base.py
@ray.method(concurrency_group="compute")
def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
self._get_ready()
if isinstance(inputs, Tensor):
return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
elif isinstance(inputs, dict):
return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
else:
raise ValueError(f'Unsupported input type "{type(inputs)}"')
@ray.method(concurrency_group="experience_io")
def _send_experience(self, experience):
'''
ignore it
# choose a trainer that has the least experience batch in its detached_replay_buffer
chosen_trainer = None
min_length = None
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
print("[maker] choosing tartget trainer")
while chosen_trainer is None:
for target_trainer in self.target_trainer_list:
try:
temp_length = ray.get(target_trainer.buffer_get_length.remote(), timeout=0.1)
if min_length is None:
min_length = temp_length
chosen_trainer = target_trainer
else:
if temp_length < min_length:
min_length = temp_length
chosen_trainer = target_trainer
except GetTimeoutError:
pass
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
print(f"[maker] sending exp to {chosen_trainer}")
chosen_trainer.buffer_append.remote(experience)
'''
#
if not hasattr(self, "_target_idx"):
self._target_idx = 0
chosen_trainer = self.target_trainer_list[self._target_idx]
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
print(f"[maker] sending exp to {chosen_trainer}")
chosen_trainer.buffer_append.remote(experience)
self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list)
def workingloop(self, dataset, tokenizer: Optional[Callable[[Any], dict]] = None, times=5000 * 50000):
self._get_ready()
sampler = self.strategy.setup_sampler(dataset)
for _ in range(times):
rand_prompts = sampler.sample(self.experience_batch_size)
if tokenizer is not None:
inputs = tokenizer(rand_prompts)
else:
inputs = rand_prompts
self._model_visit_lock.acquire()
experience = self._make_experience(inputs=inputs)
self._model_visit_lock.release()
self._send_experience(experience=experience)
@ray.method(concurrency_group="model_io")
def initialize_experience_maker(self, init_actor: Actor, init_critic: Critic):
'''
called by trainer. Only once.
'''
# TODO: reduce malloc
if self.fully_initialized:
return
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
print('[maker] INIT')
with torch.no_grad():
with self.strategy.model_init_context():
actor = init_actor
critic = init_critic
initial_model = deepcopy(actor)
reward_model = RewardModel(deepcopy(critic.model),
deepcopy(critic.value_head)).to(torch.cuda.current_device())
if self.strategy_str != 'colossalai_gemini':
actor.to(torch.float16).to(torch.cuda.current_device())
critic.to(torch.float16).to(torch.cuda.current_device())
initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device())
self.experience_maker.actor = self.strategy.prepare(actor)
self.experience_maker.critic = self.strategy.prepare(critic)
self.experience_maker.initial_model = self.strategy.prepare(initial_model)
self.experience_maker.reward_model = self.strategy.prepare(reward_model)
self.fully_initialized = True
@ray.method(concurrency_group="model_io")
def update_experience_maker(self, new_actor: Actor, new_critic: Critic):
'''
called by trainer
'''
# TODO: reduce malloc
self._model_visit_lock.acquire()
with torch.no_grad():
if 'debug' in self.generate_kwargs and self.generate_kwargs['debug'] == True:
print("[maker] UPDATE ")
if self.strategy_str != 'colossalai_gemini':
new_actor.to(torch.float16).to(torch.cuda.current_device())
new_critic.to(torch.float16).to(torch.cuda.current_device())
self.experience_maker.actor = self.strategy.prepare(new_actor)
self.experience_maker.critic = self.strategy.prepare(new_critic)
self._model_visit_lock.release()