mirror of https://github.com/hpcaitech/ColossalAI
import os
import time
from copy import deepcopy
from threading import Lock
from typing import Any, Callable, Dict, List, Optional, Union

import ray
import torch
import torch.nn as nn
from coati.experience_maker import Experience, ExperienceMaker, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel
from coati.trainer.strategies import Strategy
from coati.trainer.strategies.sampler import DistributedSampler
from ray.exceptions import GetTimeoutError
from torch import Tensor

from .utils import get_strategy_from_args, is_rank_0, set_dist_env

# A Ray actor that generates experience and streams it to detached trainers.
# The concurrency groups let experience generation ("compute"), experience
# shipping ("experience_io") and weight updates ("model_io") run on separate
# threads, each with max concurrency 1.
@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
    '''
    Args:
        detached_trainer_name_list: list of names used to look up detached trainer actor handles
        strategy: strategy name consumed by get_strategy_from_args (e.g. 'colossalai_gemini')
        experience_batch_size: batch size of generated experience
        kl_coef: the coefficient of kl divergence loss
    '''
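
    # A minimal usage sketch (an assumption, not part of this file: the actor
    # name "maker_1", trainer name "trainer_1", and `dataset`/`tokenizer` are
    # hypothetical; RAY_NAMESPACE must be set to the trainers' namespace):
    #
    #   maker = ExperienceMakerHolder.options(name="maker_1").remote(
    #       detached_trainer_name_list=["trainer_1"],
    #       strategy="colossalai_gemini",
    #       experience_batch_size=8,
    #   )
    #   maker.workingloop.remote(dataset, tokenizer)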

    def __init__(self,
                 detached_trainer_name_list: List[str],
                 strategy: str,
                 env_info: Optional[Dict[str, str]] = None,
                 experience_batch_size: int = 8,
                 kl_coef: float = 0.1,
                 **generate_kwargs):
        # set environment variables
        if env_info:
            set_dist_env(env_info=env_info)
        self.target_trainer_list = []
        for name in detached_trainer_name_list:
            self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
        self.strategy_str = strategy
        self.strategy = get_strategy_from_args(strategy)
        self.experience_batch_size = experience_batch_size
        self.kl_coef = kl_coef
        self.generate_kwargs = generate_kwargs
        # Need a trainer to give an actor and a critic via initialize_experience_maker(...)
        actor, critic, reward_model, initial_model = None, None, None, None
        self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef)
        self._model_visit_lock = Lock()
        self.fully_initialized = False
        if self.generate_kwargs.get('debug', False):
            print('[maker] Waiting for INIT')

    def _get_ready(self):
        # Block until a trainer has supplied models via initialize_experience_maker(...).
        while not self.fully_initialized:
            time.sleep(1.0)

    def update_target_trainer_list(self, detached_trainer_name_list):
        self.target_trainer_list = []
        for name in detached_trainer_name_list:
            self.target_trainer_list.append(ray.get_actor(name))

    # copied from ../trainer/base.py
    @ray.method(concurrency_group="compute")
    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
        self._get_ready()
        if isinstance(inputs, Tensor):
            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
        elif isinstance(inputs, dict):
            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
        else:
            raise ValueError(f'Unsupported input type "{type(inputs)}"')

    @ray.method(concurrency_group="experience_io")
    def _send_experience(self, experience):
        '''
        Deprecated least-loaded dispatch, kept for reference:

        # choose the trainer with the fewest experience batches in its detached_replay_buffer
        chosen_trainer = None
        min_length = None
        if self.generate_kwargs.get('debug', False):
            print("[maker] choosing target trainer")
        while chosen_trainer is None:
            for target_trainer in self.target_trainer_list:
                try:
                    temp_length = ray.get(target_trainer.buffer_get_length.remote(), timeout=0.1)
                    if min_length is None:
                        min_length = temp_length
                        chosen_trainer = target_trainer
                    elif temp_length < min_length:
                        min_length = temp_length
                        chosen_trainer = target_trainer
                except GetTimeoutError:
                    pass

        if self.generate_kwargs.get('debug', False):
            print(f"[maker] sending exp to {chosen_trainer}")
        chosen_trainer.buffer_append.remote(experience)
        '''
        # Round-robin dispatch across target trainers.
        if not hasattr(self, "_target_idx"):
            self._target_idx = 0
        chosen_trainer = self.target_trainer_list[self._target_idx]
        if self.generate_kwargs.get('debug', False):
            print(f"[maker] sending exp to {chosen_trainer}")
        chosen_trainer.buffer_append.remote(experience)
        self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list)

    def workingloop(self, dataset, tokenizer: Optional[Callable[[Any], dict]] = None, times=5000 * 50000):
        # The default `times` is effectively infinite; the maker keeps producing
        # experience until the job is torn down.
        self._get_ready()
        sampler = self.strategy.setup_sampler(dataset)
        for _ in range(times):
            rand_prompts = sampler.sample(self.experience_batch_size)
            if tokenizer is not None:
                inputs = tokenizer(rand_prompts)
            else:
                inputs = rand_prompts
            # Hold the model lock during generation so a concurrent weight update
            # cannot swap models mid-rollout.
            with self._model_visit_lock:
                experience = self._make_experience(inputs=inputs)
            self._send_experience(experience=experience)

    @ray.method(concurrency_group="model_io")
    def initialize_experience_maker(self, init_actor: Actor, init_critic: Critic):
        '''
        Called by a trainer. Runs only once.
        '''
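        # Trainer-side call sketch (assumption: `maker` is a handle to this actor):
        #   ray.get(maker.initialize_experience_maker.remote(actor, critic))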
        # TODO: reduce malloc
        if self.fully_initialized:
            return
        if self.generate_kwargs.get('debug', False):
            print('[maker] INIT')
        with torch.no_grad():
            with self.strategy.model_init_context():
                actor = init_actor
                critic = init_critic
                # The frozen initial model is a copy of the actor; the reward model
                # reuses the critic's backbone and value head.
                initial_model = deepcopy(actor)
                reward_model = RewardModel(deepcopy(critic.model),
                                           deepcopy(critic.value_head)).to(torch.cuda.current_device())
                if self.strategy_str != 'colossalai_gemini':
                    actor.to(torch.float16).to(torch.cuda.current_device())
                    critic.to(torch.float16).to(torch.cuda.current_device())
                    initial_model.to(torch.float16).to(torch.cuda.current_device())
                    reward_model.to(torch.float16).to(torch.cuda.current_device())

        self.experience_maker.actor = self.strategy.prepare(actor)
        self.experience_maker.critic = self.strategy.prepare(critic)
        self.experience_maker.initial_model = self.strategy.prepare(initial_model)
        self.experience_maker.reward_model = self.strategy.prepare(reward_model)
        self.fully_initialized = True

    @ray.method(concurrency_group="model_io")
    def update_experience_maker(self, new_actor: Actor, new_critic: Critic):
        '''
        Called by a trainer.
        '''
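        # Trainer-side call sketch (assumption: `maker` is a handle to this actor):
        #   maker.update_experience_maker.remote(new_actor, new_critic)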
        # TODO: reduce malloc
        # Take the model lock so rollouts pause while weights are swapped.
        with self._model_visit_lock:
            with torch.no_grad():
                if self.generate_kwargs.get('debug', False):
                    print("[maker] UPDATE")
                if self.strategy_str != 'colossalai_gemini':
                    new_actor.to(torch.float16).to(torch.cuda.current_device())
                    new_critic.to(torch.float16).to(torch.cuda.current_device())
                self.experience_maker.actor = self.strategy.prepare(new_actor)
                self.experience_maker.critic = self.strategy.prepare(new_critic)