ColossalAI/applications/Chat/coati/experience_buffer/utils.py

75 lines
2.3 KiB
Python
Raw Normal View History

2023-03-28 12:25:36 +00:00
from dataclasses import dataclass
from typing import List, Optional
import torch
import torch.nn.functional as F
from coati.experience_maker.base import Experience
@dataclass
class BufferItem:
    """One un-batched sample of experience data.

    Expected shape of every tensor field:
        sequences:        (S)
        action_log_probs: (A)
        values:           (1)
        reward:           (1)
        advantages:       (1)
        attention_mask:   (S)
        action_mask:      (A)

    "S" is the length of ``sequences``; "A" is the number of actions.
    """

    # Fields sized by "S" or "A" may differ in length between items.
    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    reward: torch.Tensor
    advantages: torch.Tensor
    # The two masks are optional and may be None.
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]
def split_experience_batch(experience: Experience) -> List[BufferItem]:
    """Break a batched experience into one ``BufferItem`` per sample.

    Tensor fields are unbound along dim 0. A field that is not a tensor
    (e.g. ``None``) is handed unchanged to every resulting item.

    Args:
        experience: Batched experience; the batch size is taken from dim 0
            of ``experience.sequences``.

    Returns:
        A list with one ``BufferItem`` per sample, in batch order.
    """
    field_names = (
        "sequences",
        "action_log_probs",
        "values",
        "reward",
        "advantages",
        "attention_mask",
        "action_mask",
    )
    num_samples = experience.sequences.size(0)

    # Collect one column (a per-sample sequence of values) for every field.
    columns = []
    for name in field_names:
        field = getattr(experience, name)
        if isinstance(field, torch.Tensor):
            column = torch.unbind(field)
        else:
            # Non-tensor value (None): every sample shares the same object.
            column = [field] * num_samples
        assert num_samples == len(column)
        columns.append(column)

    # Transpose the columns into rows and build one item per row.
    return [BufferItem(**dict(zip(field_names, row))) for row in zip(*columns)]
def _zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left") -> torch.Tensor:
    """Zero-pad tensors along dim 0 to a common length, then stack them.

    Args:
        sequences: Non-empty list of tensors with at least one dimension.
            They may differ in size along dim 0; any trailing dimensions
            must already match so the padded tensors can be stacked.
        side: ``"left"`` puts the zeros before the data, ``"right"`` after it.

    Returns:
        The padded tensors stacked along a new leading dimension, i.e. a
        tensor of shape ``(len(sequences), max_len, ...)``.

    Raises:
        ValueError: If ``side`` is not ``"left"``/``"right"`` or if
            ``sequences`` is empty.
    """
    # Explicit raise instead of assert: asserts vanish under ``python -O``,
    # which would silently turn any unknown ``side`` into right padding.
    if side not in ("left", "right"):
        raise ValueError(f"side must be 'left' or 'right', got {side!r}")
    if not sequences:
        raise ValueError("sequences must contain at least one tensor")

    max_len = max(seq.size(0) for seq in sequences)
    padded_sequences = []
    for seq in sequences:
        pad_len = max_len - seq.size(0)
        dim0_pad = (pad_len, 0) if side == "left" else (0, pad_len)
        # F.pad takes (before, after) pairs starting from the LAST dimension,
        # so every trailing dimension needs an explicit (0, 0) pair for the
        # final pair to land on dim 0 (the dimension the length was read from).
        # For 1-D input this reduces to just ``dim0_pad``.
        padding = (0, 0) * (seq.dim() - 1) + dim0_pad
        padded_sequences.append(F.pad(seq, padding))
    return torch.stack(padded_sequences, dim=0)
def make_experience_batch(items: List[BufferItem]) -> Experience:
    """Collate a list of ``BufferItem`` objects into one batched ``Experience``.

    ``action_log_probs`` and ``action_mask`` are zero-padded to a common
    length with ``_zero_pad_sequences`` (default side) before being batched;
    every other field is combined with ``torch.stack`` along a new dim 0.

    An optional mask field that is ``None`` on every item is passed through
    as ``None`` instead of being stacked, so items produced from an
    experience without masks can be re-batched.

    Args:
        items: Items to batch. Must be non-empty.

    Returns:
        An ``Experience`` built from the batched fields.
    """
    kwargs = {}
    to_pad_keys = {"action_log_probs", "action_mask"}
    # Fields annotated Optional on BufferItem; only these may legitimately be None.
    optional_keys = {"attention_mask", "action_mask"}
    keys = ("sequences", "action_log_probs", "values", "reward", "advantages", "attention_mask", "action_mask")
    for key in keys:
        vals = [getattr(item, key) for item in items]
        if key in optional_keys and vals and all(val is None for val in vals):
            # Neither torch.stack nor padding can handle None; keep the field unset.
            batch_data = None
        elif key in to_pad_keys:
            batch_data = _zero_pad_sequences(vals)
        else:
            batch_data = torch.stack(vals, dim=0)
        kwargs[key] = batch_data
    return Experience(**kwargs)