|
|
@ -25,12 +25,13 @@ class SLTrainer(ABC): |
|
|
|
optim (Optimizer): the optimizer to use for training |
|
|
|
optim (Optimizer): the optimizer to use for training |
|
|
|
""" |
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
def __init__( |
|
|
|
strategy: Strategy, |
|
|
|
self, |
|
|
|
max_epochs: int, |
|
|
|
strategy: Strategy, |
|
|
|
model: nn.Module, |
|
|
|
max_epochs: int, |
|
|
|
optimizer: Optimizer, |
|
|
|
model: nn.Module, |
|
|
|
) -> None: |
|
|
|
optimizer: Optimizer, |
|
|
|
|
|
|
|
) -> None: |
|
|
|
super().__init__() |
|
|
|
super().__init__() |
|
|
|
self.strategy = strategy |
|
|
|
self.strategy = strategy |
|
|
|
self.max_epochs = max_epochs |
|
|
|
self.max_epochs = max_epochs |
|
|
@ -50,10 +51,7 @@ class SLTrainer(ABC): |
|
|
|
|
|
|
|
|
|
|
|
def fit(self, *args, **kwargs): |
|
|
|
def fit(self, *args, **kwargs): |
|
|
|
self._before_fit(*args, **kwargs) |
|
|
|
self._before_fit(*args, **kwargs) |
|
|
|
for epoch in tqdm.trange(self.max_epochs, |
|
|
|
for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0() or self.no_epoch_bar): |
|
|
|
desc="Epochs", |
|
|
|
|
|
|
|
disable=not is_rank_0() or self.no_epoch_bar |
|
|
|
|
|
|
|
): |
|
|
|
|
|
|
|
self._train(epoch) |
|
|
|
self._train(epoch) |
|
|
|
self._eval(epoch) |
|
|
|
self._eval(epoch) |
|
|
|
|
|
|
|
|
|
|
@ -75,8 +73,7 @@ class OnPolicyTrainer(ABC): |
|
|
|
buffer: NaiveReplayBuffer, |
|
|
|
buffer: NaiveReplayBuffer, |
|
|
|
sample_buffer: bool, |
|
|
|
sample_buffer: bool, |
|
|
|
dataloader_pin_memory: bool, |
|
|
|
dataloader_pin_memory: bool, |
|
|
|
callbacks: List[Callback] = [] |
|
|
|
callbacks: List[Callback] = []) -> None: |
|
|
|
) -> None: |
|
|
|
|
|
|
|
super().__init__() |
|
|
|
super().__init__() |
|
|
|
self.strategy = strategy |
|
|
|
self.strategy = strategy |
|
|
|
self.buffer = buffer |
|
|
|
self.buffer = buffer |
|
|
@ -138,7 +135,7 @@ class OnPolicyTrainer(ABC): |
|
|
|
@abstractmethod |
|
|
|
@abstractmethod |
|
|
|
def _learn(self, update_step: int): |
|
|
|
def _learn(self, update_step: int): |
|
|
|
""" |
|
|
|
""" |
|
|
|
Implement this method to learn from experience, either |
|
|
|
Implement this method to learn from experience, either |
|
|
|
sample from buffer or transform buffer into dataloader. |
|
|
|
sample from buffer or transform buffer into dataloader. |
|
|
|
""" |
|
|
|
""" |
|
|
|
raise NotImplementedError() |
|
|
|
raise NotImplementedError() |
|
|
@ -154,13 +151,14 @@ class OnPolicyTrainer(ABC): |
|
|
|
self._learn(update_step) |
|
|
|
self._learn(update_step) |
|
|
|
self._on_learn_epoch_end(update_step) |
|
|
|
self._on_learn_epoch_end(update_step) |
|
|
|
|
|
|
|
|
|
|
|
def fit(self, |
|
|
|
def fit( |
|
|
|
prompt_dataloader: DataLoader, |
|
|
|
self, |
|
|
|
pretrain_dataloader: DataLoader, |
|
|
|
prompt_dataloader: DataLoader, |
|
|
|
num_episodes: int, |
|
|
|
pretrain_dataloader: DataLoader, |
|
|
|
num_collect_steps: int, |
|
|
|
num_episodes: int, |
|
|
|
num_update_steps: int, |
|
|
|
num_collect_steps: int, |
|
|
|
): |
|
|
|
num_update_steps: int, |
|
|
|
|
|
|
|
): |
|
|
|
""" |
|
|
|
""" |
|
|
|
The main training loop of on-policy rl trainers. |
|
|
|
The main training loop of on-policy rl trainers. |
|
|
|
|
|
|
|
|
|
|
@ -175,23 +173,16 @@ class OnPolicyTrainer(ABC): |
|
|
|
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader) |
|
|
|
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader) |
|
|
|
|
|
|
|
|
|
|
|
with self._fit_ctx(): |
|
|
|
with self._fit_ctx(): |
|
|
|
for episode in tqdm.trange(num_episodes, |
|
|
|
for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()): |
|
|
|
desc="Episodes", |
|
|
|
|
|
|
|
disable=not is_rank_0()): |
|
|
|
|
|
|
|
with self._episode_ctx(episode): |
|
|
|
with self._episode_ctx(episode): |
|
|
|
for collect_step in tqdm.trange(num_collect_steps, |
|
|
|
for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()): |
|
|
|
desc="Collect steps", |
|
|
|
|
|
|
|
disable=not is_rank_0()): |
|
|
|
|
|
|
|
self._collect_phase(collect_step) |
|
|
|
self._collect_phase(collect_step) |
|
|
|
if not self.sample_buffer: |
|
|
|
if not self.sample_buffer: |
|
|
|
# HACK(cwher): according to the design of boost API, dataloader should also be boosted, |
|
|
|
# HACK(cwher): according to the design of boost API, dataloader should also be boosted, |
|
|
|
# but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted. |
|
|
|
# but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted. |
|
|
|
# I only call strategy.setup_dataloader() to setup dataloader. |
|
|
|
# I only call strategy.setup_dataloader() to setup dataloader. |
|
|
|
self.dataloader = self.strategy.setup_dataloader(self.buffer, |
|
|
|
self.dataloader = self.strategy.setup_dataloader(self.buffer, self.dataloader_pin_memory) |
|
|
|
self.dataloader_pin_memory) |
|
|
|
for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()): |
|
|
|
for update_step in tqdm.trange(num_update_steps, |
|
|
|
|
|
|
|
desc="Update steps", |
|
|
|
|
|
|
|
disable=not is_rank_0()): |
|
|
|
|
|
|
|
self._update_phase(update_step) |
|
|
|
self._update_phase(update_step) |
|
|
|
# NOTE: this is for on-policy algorithms |
|
|
|
# NOTE: this is for on-policy algorithms |
|
|
|
self.buffer.clear() |
|
|
|
self.buffer.clear() |
|
|
|