You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/applications/ColossalChat/coati/ray/experience_maker_holder.py

275 lines
11 KiB

import os
import time
import tracemalloc
from threading import Lock
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
import ray
import torch
from coati.experience_buffer.utils import split_experience_batch
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel
from coati.trainer.strategies import Strategy
from torch import Tensor
from tqdm import tqdm
from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
from .lora_constructor import LoRAConstructor
from .utils import get_model_numel, get_rank, is_rank_0, set_dist_env, state_dict_to
@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
    """Ray actor that makes experience from batches and sends it to detached trainers.

    Args:
        detached_trainer_name_list: Names of the detached trainer Ray actors. Their handles are
            looked up lazily in the ``RAY_NAMESPACE`` namespace. Experience items are distributed
            among them round-robin.
        strategy_fn: Zero-argument factory returning the ``Strategy`` used to build and prepare the models.
        model_fn: Zero-argument factory returning ``(actor, critic, reward_model, initial_model)``.
        env_info: Optional environment info passed to ``set_dist_env`` before anything else runs.
        sync_models_from_trainers: Whether to sync models from trainers. If True, you must call
            sync_models_to_remote_makers() in trainers to sync models. ``workingloop`` waits until
            an ``update_experience_maker(..., fully_update=True, chunk_end=True)`` call has arrived.
        buffer_cpu_offload: Move each experience to CPU before sending it to the trainers.
        kl_coef: The coefficient of the KL divergence loss.
        callbacks: Optional list of ``MakerCallback`` hooks. ``None`` (the default) means no callbacks.
            The list is copied, so the caller's list is never mutated or aliased.
        eval_performance: If True, an ``ExperienceMakerPerformanceEvaluator`` is appended to the callbacks.
        debug: Print progress messages and, during model updates, trace Python memory with ``tracemalloc``.
        update_lora_weights: If True, non-full model updates are applied as increments built by
            ``LoRAConstructor`` instead of via ``load_state_dict``.
        **generate_kwargs: Extra keyword arguments forwarded to ``make_experience``.
    """

    def __init__(
        self,
        detached_trainer_name_list: List[str],
        strategy_fn: Callable[[], Strategy],
        # a function returns (actor, critic, reward_model, initial_model)
        model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
        env_info: Dict[str, str] = None,
        sync_models_from_trainers: bool = False,
        buffer_cpu_offload: bool = True,
        kl_coef: float = 0.1,
        callbacks: List[MakerCallback] = None,
        eval_performance: bool = False,
        debug: bool = False,
        update_lora_weights: bool = False,
        **generate_kwargs,
    ):
        # set environment variables
        if env_info:
            set_dist_env(env_info=env_info)
        self.target_trainer_list = []
        assert len(detached_trainer_name_list) > 0
        self._detached_trainer_name_list = detached_trainer_name_list
        self.strategy = strategy_fn()
        self.buffer_cpu_offload = buffer_cpu_offload
        self.kl_coef = kl_coef
        # Work on a private copy so self.callbacks never aliases a shared default or the caller's list.
        callbacks = list(callbacks) if callbacks is not None else []
        # init models
        with self.strategy.model_init_context():
            actor, critic, reward_model, initial_model = model_fn()
        self.generate_kwargs = _set_default_generate_kwargs(generate_kwargs, actor)
        if eval_performance:
            # Parameter counts are taken before strategy.prepare() wraps the models.
            actor_numel = get_model_numel(actor)
            critic_numel = get_model_numel(critic)
            initial_model_numel = get_model_numel(initial_model)
            reward_model_numel = get_model_numel(reward_model)
            evaluator = ExperienceMakerPerformanceEvaluator(
                actor_numel, critic_numel, initial_model_numel, reward_model_numel
            )
            callbacks.append(evaluator)
        actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model)
        self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef)
        self.callbacks = callbacks
        # Guards model access: held while making experience and across a chunked model update.
        self._model_visit_lock = Lock()
        self._is_fully_initialized = not sync_models_from_trainers
        self._debug = debug
        self._update_lora_weights = update_lora_weights
        if self._update_lora_weights:
            self.actor_lora_constructor = LoRAConstructor()
            self.critic_lora_constructor = LoRAConstructor()
        self.target_auto_balance = False
        # Index of the trainer that receives the next experience item (round-robin).
        self._target_idx = 0
        if self._debug:
            print(f"[maker{get_rank()}] will send items to {self._detached_trainer_name_list}")
            if not self._is_fully_initialized:
                print(f"[maker{get_rank()}] Waiting for INIT")

    def _get_ready(self):
        """Block, polling once per second, until the models are fully initialized."""
        while not self._fully_initialized():
            time.sleep(1.0)

    def _fully_initialized(self):
        """Return whether the maker may start producing experience."""
        return self._is_fully_initialized

    def _init_target_trainer_list(self):
        """Resolve the trainer actor handles by name on first use. Later calls do nothing."""
        if len(self.target_trainer_list) > 0:
            return
        for name in self._detached_trainer_name_list:
            self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))

    # copy from ../trainer/base.py
    @ray.method(concurrency_group="compute")
    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
        """Run the experience maker on a tensor batch or a dict of keyword tensors.

        Raises:
            ValueError: If ``inputs`` is neither a ``Tensor`` nor a ``dict``.
        """
        if isinstance(inputs, Tensor):
            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
        elif isinstance(inputs, dict):
            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
        else:
            raise ValueError(f'Unsupported input type "{type(inputs)}"')

    @ray.method(concurrency_group="experience_io")
    def _send_items(self, experience: Experience) -> None:
        """Split ``experience`` into items and send them round-robin to the trainers' buffers.

        ``self._target_idx`` persists across calls, so the distribution continues where the
        previous batch stopped. Each trainer gets at most one ``buffer_extend`` call per batch.
        """
        self._init_target_trainer_list()
        items = split_experience_batch(experience)
        items_per_trainer = [[] for _ in range(len(self.target_trainer_list))]
        for item in items:
            items_per_trainer[self._target_idx].append(item)
            self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list)
        for i, target_trainer in enumerate(self.target_trainer_list):
            if len(items_per_trainer[i]) > 0:
                target_trainer.buffer_extend.remote(items_per_trainer[i])

    def _inference_step(self, batch) -> None:
        """Make experience from one batch under the model lock, then send it without holding the lock."""
        self._on_batch_start()
        with self._model_visit_lock:
            self._on_make_experience_start()
            experience = self._make_experience(batch)
            self._on_make_experience_end(experience)
        self._on_send_start()
        if self.buffer_cpu_offload:
            experience.to_device("cpu")
        self._send_items(experience)
        self._on_send_end()
        self._on_batch_end()

    def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1, num_steps: int = 0):
        """Working loop of the experience maker.

        Args:
            dataloader_fn (Callable[[], Iterable]): A function that returns a dataloader.
            num_epochs (int, optional): Iterate the dataloader for number of epochs. Defaults to 1.
            num_steps (int, optional): Iterate the dataloader for number of steps. If this value > 0,
                num_epochs will be ignored and the dataloader is restarted whenever it is exhausted.
                Defaults to 0.
        """
        self._get_ready()
        self._on_loop_start()
        dataloader = dataloader_fn()
        if num_steps > 0:
            # ignore num epochs
            it = iter(dataloader)
            for _ in tqdm(range(num_steps), desc="ExperienceMaker", disable=not is_rank_0()):
                try:
                    batch = next(it)
                except StopIteration:
                    # Dataloader exhausted: start a new pass over it.
                    it = iter(dataloader)
                    batch = next(it)
                self._inference_step(batch)
        else:
            with tqdm(total=num_epochs * len(dataloader), desc="ExperienceMaker", disable=not is_rank_0()) as pbar:
                for _ in range(num_epochs):
                    for batch in dataloader:
                        self._inference_step(batch)
                        pbar.update()
        self._on_loop_end()

    @ray.method(concurrency_group="model_io")
    def update_experience_maker(
        self,
        new_actor_state_dict: Dict[str, Any] = None,
        new_actor_lora_config_dict: Dict[str, Any] = None,
        new_critic_state_dict: Dict[str, Any] = None,
        new_critic_lora_config_dict: Dict[str, Any] = None,
        fully_update: bool = False,
        chunk_start: bool = None,
        chunk_end: bool = None,
    ):
        """
        called by trainer
        chunk_start: Set True at the first call. Before sending state_dict calls
        chunk_end: Set True at the last call. After sending state_dict calls.
        fully_update: Set True if you want to sync models when initializing
        TODO: load_state_dict integrate with model-sharding strategy

        The model lock is acquired on the ``chunk_start`` call and released on the ``chunk_end``
        call. Experience making is therefore blocked for the whole chunked update.
        When ``fully_update`` is False and ``update_lora_weights`` was enabled, each state dict is
        moved to the current CUDA device and applied as a LoRA increment. Otherwise it is loaded
        with ``load_state_dict(strict=False)``.
        """
        _watch_memory = self._debug
        if chunk_start:
            if self._debug:
                print("[maker] UPDATE ")
            if _watch_memory:
                tracemalloc.start()
            # Deliberately not a `with`: the lock stays held until the matching chunk_end call.
            self._model_visit_lock.acquire()

        with torch.no_grad():
            if new_actor_state_dict is not None:
                if not self._update_lora_weights or fully_update:
                    self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
                else:
                    new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
                    state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
                        new_actor_state_dict, new_actor_lora_config_dict
                    )
                    self.actor_lora_constructor.load_state_dict_increase(
                        self.experience_maker.actor.model, state_dict_increase
                    )
            if new_critic_state_dict is not None:
                if not self._update_lora_weights or fully_update:
                    self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
                else:
                    new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
                    state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
                        new_critic_state_dict, new_critic_lora_config_dict
                    )
                    self.critic_lora_constructor.load_state_dict_increase(
                        self.experience_maker.critic, state_dict_increase
                    )

        # the lock must be released after both actor and critic being updated
        if chunk_end:
            self._model_visit_lock.release()
            if _watch_memory:
                current, peak = tracemalloc.get_traced_memory()
                print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
                tracemalloc.stop()
            if fully_update:
                # Unblocks _get_ready() in workingloop.
                self._is_fully_initialized = True

    # --- callback dispatch: forward each lifecycle event to every registered callback, in order ---

    def _on_make_experience_start(self) -> None:
        for callback in self.callbacks:
            callback.on_make_experience_start()

    def _on_make_experience_end(self, experience: Experience) -> None:
        for callback in self.callbacks:
            callback.on_make_experience_end(experience)

    def _on_loop_start(self) -> None:
        for callback in self.callbacks:
            callback.on_loop_start()

    def _on_loop_end(self) -> None:
        for callback in self.callbacks:
            callback.on_loop_end()

    def _on_send_start(self) -> None:
        for callback in self.callbacks:
            callback.on_send_start()

    def _on_send_end(self) -> None:
        for callback in self.callbacks:
            callback.on_send_end()

    def _on_batch_start(self) -> None:
        for callback in self.callbacks:
            callback.on_batch_start()

    def _on_batch_end(self) -> None:
        for callback in self.callbacks:
            callback.on_batch_end()
def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> Dict[str, Any]:
    """Return a copy of ``generate_kwargs`` with HuggingFace generation hooks filled in as defaults.

    If the caller did not supply ``prepare_inputs_fn`` (or ``update_model_kwargs_fn``) and
    ``actor.model`` has a ``prepare_inputs_for_generation`` (or
    ``_update_model_kwargs_for_generation``) attribute, that bound method is used as the value.
    Keys the caller already supplied are never overridden, and ``generate_kwargs`` itself is
    not modified.

    Args:
        generate_kwargs: User-supplied generation keyword arguments.
        actor: Actor whose ``.model`` is inspected for the generation hooks.

    Returns:
        A new dict containing every entry of ``generate_kwargs`` plus any defaults added.
    """
    origin_model = actor.model
    new_kwargs = {**generate_kwargs}
    # use huggingface models method directly
    if "prepare_inputs_fn" not in generate_kwargs and hasattr(origin_model, "prepare_inputs_for_generation"):
        new_kwargs["prepare_inputs_fn"] = origin_model.prepare_inputs_for_generation
    if "update_model_kwargs_fn" not in generate_kwargs and hasattr(origin_model, "_update_model_kwargs_for_generation"):
        new_kwargs["update_model_kwargs_fn"] = origin_model._update_model_kwargs_for_generation
    return new_kwargs