From 798cb72907f5425dccc94c83fd1a900a6b8f67eb Mon Sep 17 00:00:00 2001 From: shenggan Date: Tue, 18 Jul 2023 10:59:57 +0800 Subject: [PATCH] [NFC] polish applications/Chat/coati/trainer/base.py code style (#4260) --- applications/Chat/coati/trainer/base.py | 53 ++++++++++--------------- 1 file changed, 22 insertions(+), 31 deletions(-) diff --git a/applications/Chat/coati/trainer/base.py b/applications/Chat/coati/trainer/base.py index 13571cdcc..b4d168a56 100644 --- a/applications/Chat/coati/trainer/base.py +++ b/applications/Chat/coati/trainer/base.py @@ -25,12 +25,13 @@ class SLTrainer(ABC): optim (Optimizer): the optimizer to use for training """ - def __init__(self, - strategy: Strategy, - max_epochs: int, - model: nn.Module, - optimizer: Optimizer, - ) -> None: + def __init__( + self, + strategy: Strategy, + max_epochs: int, + model: nn.Module, + optimizer: Optimizer, + ) -> None: super().__init__() self.strategy = strategy self.max_epochs = max_epochs @@ -50,10 +51,7 @@ class SLTrainer(ABC): def fit(self, *args, **kwargs): self._before_fit(*args, **kwargs) - for epoch in tqdm.trange(self.max_epochs, - desc="Epochs", - disable=not is_rank_0() or self.no_epoch_bar - ): + for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0() or self.no_epoch_bar): self._train(epoch) self._eval(epoch) @@ -75,8 +73,7 @@ class OnPolicyTrainer(ABC): buffer: NaiveReplayBuffer, sample_buffer: bool, dataloader_pin_memory: bool, - callbacks: List[Callback] = [] - ) -> None: + callbacks: List[Callback] = []) -> None: super().__init__() self.strategy = strategy self.buffer = buffer @@ -138,7 +135,7 @@ class OnPolicyTrainer(ABC): @abstractmethod def _learn(self, update_step: int): """ - Implement this method to learn from experience, either + Implement this method to learn from experience, either sample from buffer or transform buffer into dataloader. """ raise NotImplementedError() @@ -154,13 +151,14 @@ class OnPolicyTrainer(ABC): self._learn(update_step) self._on_learn_epoch_end(update_step) - def fit(self, - prompt_dataloader: DataLoader, - pretrain_dataloader: DataLoader, - num_episodes: int, - num_collect_steps: int, - num_update_steps: int, - ): + def fit( + self, + prompt_dataloader: DataLoader, + pretrain_dataloader: DataLoader, + num_episodes: int, + num_collect_steps: int, + num_update_steps: int, + ): """ The main training loop of on-policy rl trainers. @@ -175,23 +173,16 @@ class OnPolicyTrainer(ABC): self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader) with self._fit_ctx(): - for episode in tqdm.trange(num_episodes, - desc="Episodes", - disable=not is_rank_0()): + for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()): with self._episode_ctx(episode): - for collect_step in tqdm.trange(num_collect_steps, - desc="Collect steps", - disable=not is_rank_0()): + for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()): self._collect_phase(collect_step) if not self.sample_buffer: # HACK(cwher): according to the design of boost API, dataloader should also be boosted, # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted. # I only call strategy.setup_dataloader() to setup dataloader. - self.dataloader = self.strategy.setup_dataloader(self.buffer, - self.dataloader_pin_memory) - for update_step in tqdm.trange(num_update_steps, - desc="Update steps", - disable=not is_rank_0()): + self.dataloader = self.strategy.setup_dataloader(self.buffer, self.dataloader_pin_memory) + for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()): self._update_phase(update_step) # NOTE: this is for on-policy algorithms self.buffer.clear()