mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish applications/Chat/coati/trainer/base.py code style (#4260)
parent b2debdc09b
commit 798cb72907
@@ -25,12 +25,13 @@ class SLTrainer(ABC):
         optim (Optimizer): the optimizer to use for training
     """
 
-    def __init__(self,
-                 strategy: Strategy,
-                 max_epochs: int,
-                 model: nn.Module,
-                 optimizer: Optimizer,
-                 ) -> None:
+    def __init__(
+        self,
+        strategy: Strategy,
+        max_epochs: int,
+        model: nn.Module,
+        optimizer: Optimizer,
+    ) -> None:
         super().__init__()
         self.strategy = strategy
         self.max_epochs = max_epochs
@@ -50,10 +51,7 @@ class SLTrainer(ABC):
 
     def fit(self, *args, **kwargs):
         self._before_fit(*args, **kwargs)
-        for epoch in tqdm.trange(self.max_epochs,
-                                 desc="Epochs",
-                                 disable=not is_rank_0() or self.no_epoch_bar
-                                 ):
+        for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0() or self.no_epoch_bar):
             self._train(epoch)
             self._eval(epoch)
 
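For context, SLTrainer.fit() only orchestrates three hooks: _before_fit() receives whatever arguments fit() was called with, then _train() and _eval() run once per epoch. Below is a minimal sketch of a concrete subclass, assuming the base constructor keeps references to model and optimizer and that the Strategy object offers backward()/optimizer_step() helpers; the dataloader arguments and the HuggingFace-style .loss output are also assumptions, not code from this PR.

# Minimal sketch of a concrete SLTrainer subclass (illustration only).
# The hook names come from the fit() loop above; the import path matches the
# file being polished, everything else is assumed for the example.
from torch.utils.data import DataLoader

from coati.trainer.base import SLTrainer


class ToySFTTrainer(SLTrainer):
    # Inherits __init__(strategy, max_epochs, model, optimizer) from SLTrainer.

    def _before_fit(self, train_dataloader: DataLoader, eval_dataloader: DataLoader):
        # fit(*args, **kwargs) forwards its arguments here before the epoch loop starts.
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.no_epoch_bar = False  # progress-bar flag read by fit(); set defensively here

    def _train(self, epoch: int):
        self.model.train()  # assumes the base constructor stored `model` on self
        for batch in self.train_dataloader:
            loss = self.model(**batch).loss  # assumes a HuggingFace-style model output
            self.strategy.backward(loss, self.model, self.optimizer)  # assumed Strategy helpers
            self.strategy.optimizer_step(self.optimizer)
            self.optimizer.zero_grad()

    def _eval(self, epoch: int):
        self.model.eval()
        # evaluation omitted for brevity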
@@ -75,8 +73,7 @@ class OnPolicyTrainer(ABC):
                  buffer: NaiveReplayBuffer,
                  sample_buffer: bool,
                  dataloader_pin_memory: bool,
-                 callbacks: List[Callback] = []
-                 ) -> None:
+                 callbacks: List[Callback] = []) -> None:
         super().__init__()
         self.strategy = strategy
         self.buffer = buffer
@@ -138,7 +135,7 @@ class OnPolicyTrainer(ABC):
     @abstractmethod
     def _learn(self, update_step: int):
         """
-        Implement this method to learn from experience, either
+        Implement this method to learn from experience, either
         sample from buffer or transform buffer into dataloader.
         """
         raise NotImplementedError()
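Since _learn() is the main extension point, here is a sketch of the two styles the docstring mentions. It assumes the replay buffer exposes a sample() method and uses a hypothetical _training_step() helper; self.sample_buffer and self.dataloader are the attributes set by OnPolicyTrainer itself.

    # Sketch only, not code from this PR: one possible _learn() for a subclass.
    def _learn(self, update_step: int):
        if self.sample_buffer:
            # Style 1: draw experience straight from the replay buffer.
            experience = self.buffer.sample()  # assumes NaiveReplayBuffer.sample() exists
            self._training_step(experience)    # hypothetical helper doing the actual update
        else:
            # Style 2: iterate over the dataloader that fit() built from the buffer.
            for experience in self.dataloader:
                self._training_step(experience)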
@@ -154,13 +151,14 @@ class OnPolicyTrainer(ABC):
             self._learn(update_step)
             self._on_learn_epoch_end(update_step)
 
-    def fit(self,
-            prompt_dataloader: DataLoader,
-            pretrain_dataloader: DataLoader,
-            num_episodes: int,
-            num_collect_steps: int,
-            num_update_steps: int,
-            ):
+    def fit(
+        self,
+        prompt_dataloader: DataLoader,
+        pretrain_dataloader: DataLoader,
+        num_episodes: int,
+        num_collect_steps: int,
+        num_update_steps: int,
+    ):
         """
         The main training loop of on-policy rl trainers.
 
@@ -175,23 +173,16 @@ class OnPolicyTrainer(ABC):
         self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
 
         with self._fit_ctx():
-            for episode in tqdm.trange(num_episodes,
-                                       desc="Episodes",
-                                       disable=not is_rank_0()):
+            for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
                 with self._episode_ctx(episode):
-                    for collect_step in tqdm.trange(num_collect_steps,
-                                                    desc="Collect steps",
-                                                    disable=not is_rank_0()):
+                    for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()):
                         self._collect_phase(collect_step)
                     if not self.sample_buffer:
                         # HACK(cwher): according to the design of boost API, dataloader should also be boosted,
                         # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
                         # I only call strategy.setup_dataloader() to setup dataloader.
-                        self.dataloader = self.strategy.setup_dataloader(self.buffer,
-                                                                         self.dataloader_pin_memory)
-                    for update_step in tqdm.trange(num_update_steps,
-                                                   desc="Update steps",
-                                                   disable=not is_rank_0()):
+                        self.dataloader = self.strategy.setup_dataloader(self.buffer, self.dataloader_pin_memory)
+                    for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
                         self._update_phase(update_step)
                     # NOTE: this is for on-policy algorithms
                     self.buffer.clear()
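For reference, a hypothetical call site for the reformatted fit() signature; the trainer subclass, dataloaders, and step counts below are placeholders, not code from this PR.

# Placeholder call site, not code from this PR.
trainer = MyPPOTrainer(strategy, buffer, sample_buffer=False, dataloader_pin_memory=True)  # hypothetical OnPolicyTrainer subclass
trainer.fit(
    prompt_dataloader=prompt_dataloader,
    pretrain_dataloader=pretrain_dataloader,
    num_episodes=10,       # outer episode loop
    num_collect_steps=4,   # experience-collection passes per episode
    num_update_steps=2,    # learning passes per episode; the buffer is cleared afterwards
)

Per the loop above, each episode collects experience num_collect_steps times, optionally wraps the buffer in a dataloader, runs num_update_steps update phases, and then clears the buffer.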