From 9e512938f6b0b79c2d61c12d4fdc3b4a0008362e Mon Sep 17 00:00:00 2001
From: Zirui Zhu
Date: Wed, 19 Jul 2023 22:18:08 +0800
Subject: [PATCH] [NFC] polish applications/Chat/coati/trainer/strategies/base.py code style (#4278)

---
 .../Chat/coati/trainer/strategies/base.py    | 22 ++++------------------
 1 file changed, 4 insertions(+), 18 deletions(-)

diff --git a/applications/Chat/coati/trainer/strategies/base.py b/applications/Chat/coati/trainer/strategies/base.py
index 80bc32728..3d1dfaf78 100644
--- a/applications/Chat/coati/trainer/strategies/base.py
+++ b/applications/Chat/coati/trainer/strategies/base.py
@@ -79,8 +79,7 @@ class Strategy(ABC):
                     model, optimizer = arg
                 except ValueError:
                     raise RuntimeError(f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"')
-                model, optimizer, *_ = self.booster.boost(model=model,
-                                                          optimizer=optimizer)
+                model, optimizer, *_ = self.booster.boost(model=model, optimizer=optimizer)
                 rets.append((model, optimizer))
             elif isinstance(arg, Dict):
                 model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
@@ -87,13 +86,10 @@
                 boost_result = dict(model=model,
                                     optimizer=optimizer,
                                     criterion=criterion,
                                     dataloader=dataloader,
                                     lr_scheduler=lr_scheduler)
                 # remove None values
-                boost_result = {
-                    key: value
-                    for key, value in boost_result.items() if value is not None
-                }
+                boost_result = {key: value for key, value in boost_result.items() if value is not None}
                 rets.append(boost_result)
             else:
                 raise RuntimeError(f'Type {type(arg)} is not supported')
@@ -112,23 +108,13 @@
         """
         return model
 
-    def save_model(self,
-                   model: nn.Module,
-                   path: str,
-                   only_rank0: bool = True,
-                   **kwargs
-                   ) -> None:
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True, **kwargs) -> None:
         self.booster.save_model(model, path, shard=not only_rank0, **kwargs)
 
     def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
         self.booster.load_model(model, path, strict)
 
-    def save_optimizer(self,
-                       optimizer: Optimizer,
-                       path: str,
-                       only_rank0: bool = False,
-                       **kwargs
-                       ) -> None:
+    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False, **kwargs) -> None:
         self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs)
 
     def load_optimizer(self, optimizer: Optimizer, path: str) -> None:
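
Note (not part of the patch): for context on the API being polished, `prepare` forwards `(model, optimizer)` tuples or keyword dicts to `Booster.boost`, and `save_model`/`save_optimizer` translate `only_rank0` into the Booster's `shard` flag (`shard=not only_rank0`). A minimal usage sketch, assuming a concrete `NaiveStrategy` subclass and a toy model (both are illustrative assumptions, not taken from this diff):

    import torch.nn as nn
    from torch.optim import Adam

    # Assumed concrete Strategy subclass; substitute whichever strategy
    # (e.g. a DDP- or ColossalAI-based one) your setup actually provides.
    from coati.trainer.strategies import NaiveStrategy

    strategy = NaiveStrategy()

    # Toy model/optimizer, for illustration only.
    model = nn.Linear(8, 2)
    optimizer = Adam(model.parameters(), lr=1e-3)

    # Tuple form: prepare() boosts the pair and returns it in order.
    model, optimizer = strategy.prepare((model, optimizer))

    # only_rank0=True writes a single unsharded checkpoint
    # (save_model passes shard=not only_rank0 to the Booster).
    strategy.save_model(model, "actor_checkpoint.pt", only_rank0=True)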