
[NFC] polish applications/Chat/coati/trainer/base.py code style (#4260)

Branch: pull/4338/head
Authored by shenggan (1 year ago), committed by binmakeswell
Commit 798cb72907
1 changed file with 22 additions and 31 deletions

applications/Chat/coati/trainer/base.py
@@ -25,12 +25,13 @@ class SLTrainer(ABC):
         optim (Optimizer): the optimizer to use for training
     """
 
-    def __init__(self,
-                 strategy: Strategy,
-                 max_epochs: int,
-                 model: nn.Module,
-                 optimizer: Optimizer,
-                 ) -> None:
+    def __init__(
+        self,
+        strategy: Strategy,
+        max_epochs: int,
+        model: nn.Module,
+        optimizer: Optimizer,
+    ) -> None:
         super().__init__()
         self.strategy = strategy
         self.max_epochs = max_epochs
@@ -50,10 +51,7 @@ class SLTrainer(ABC):
 
     def fit(self, *args, **kwargs):
         self._before_fit(*args, **kwargs)
-        for epoch in tqdm.trange(self.max_epochs,
-                                 desc="Epochs",
-                                 disable=not is_rank_0() or self.no_epoch_bar
-                                 ):
+        for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0() or self.no_epoch_bar):
             self._train(epoch)
             self._eval(epoch)
 
@@ -75,8 +73,7 @@ class OnPolicyTrainer(ABC):
                  buffer: NaiveReplayBuffer,
                  sample_buffer: bool,
                  dataloader_pin_memory: bool,
-                 callbacks: List[Callback] = []
-                 ) -> None:
+                 callbacks: List[Callback] = []) -> None:
         super().__init__()
         self.strategy = strategy
         self.buffer = buffer
@ -138,7 +135,7 @@ class OnPolicyTrainer(ABC):
@abstractmethod @abstractmethod
def _learn(self, update_step: int): def _learn(self, update_step: int):
""" """
Implement this method to learn from experience, either Implement this method to learn from experience, either
sample from buffer or transform buffer into dataloader. sample from buffer or transform buffer into dataloader.
""" """
raise NotImplementedError() raise NotImplementedError()
@@ -154,13 +151,14 @@ class OnPolicyTrainer(ABC):
             self._learn(update_step)
             self._on_learn_epoch_end(update_step)
 
-    def fit(self,
-            prompt_dataloader: DataLoader,
-            pretrain_dataloader: DataLoader,
-            num_episodes: int,
-            num_collect_steps: int,
-            num_update_steps: int,
-            ):
+    def fit(
+        self,
+        prompt_dataloader: DataLoader,
+        pretrain_dataloader: DataLoader,
+        num_episodes: int,
+        num_collect_steps: int,
+        num_update_steps: int,
+    ):
         """
         The main training loop of on-policy rl trainers.
 
@@ -175,23 +173,16 @@ class OnPolicyTrainer(ABC):
         self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
 
         with self._fit_ctx():
-            for episode in tqdm.trange(num_episodes,
-                                       desc="Episodes",
-                                       disable=not is_rank_0()):
+            for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
                 with self._episode_ctx(episode):
-                    for collect_step in tqdm.trange(num_collect_steps,
-                                                    desc="Collect steps",
-                                                    disable=not is_rank_0()):
+                    for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()):
                         self._collect_phase(collect_step)
                     if not self.sample_buffer:
                         # HACK(cwher): according to the design of boost API, dataloader should also be boosted,
                         #  but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
                         #  I only call strategy.setup_dataloader() to setup dataloader.
-                        self.dataloader = self.strategy.setup_dataloader(self.buffer,
-                                                                         self.dataloader_pin_memory)
-                    for update_step in tqdm.trange(num_update_steps,
-                                                   desc="Update steps",
-                                                   disable=not is_rank_0()):
+                        self.dataloader = self.strategy.setup_dataloader(self.buffer, self.dataloader_pin_memory)
+                    for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
                         self._update_phase(update_step)
                     # NOTE: this is for on-policy algorithms
                     self.buffer.clear()
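
The one pattern every reformatted call above shares is `tqdm.trange(..., disable=not is_rank_0())`: the progress bar is drawn only on rank 0 of a distributed job (SLTrainer additionally lets `self.no_epoch_bar` silence it). Below is a minimal sketch of that behaviour, not part of the commit; the `is_rank_0()` helper shown here is only an assumption about how coati's trainer utility behaves, written against torch.distributed.

```python
# Minimal sketch, not part of the commit: the `disable=not is_rank_0()`
# idiom used throughout base.py. The helper below is an assumption about
# how coati's is_rank_0() is implemented.
import tqdm
import torch.distributed as dist


def is_rank_0() -> bool:
    # A single-process run counts as rank 0; in a distributed run only
    # the process with global rank 0 returns True and shows the bar.
    return not dist.is_initialized() or dist.get_rank() == 0


for epoch in tqdm.trange(3, desc="Epochs", disable=not is_rank_0()):
    pass  # one epoch of training would go here
```

Because `disable=True` turns tqdm into a silent pass-through iterator, the loop body runs identically on every rank; only the console output differs.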
