[NFC] polish applications/Chat/coati/trainer/strategies/base.py code style (#4278)

pull/4338/head
Zirui Zhu 1 year ago committed by binmakeswell
parent c972d65311
commit 9e512938f6

@ -79,8 +79,7 @@ class Strategy(ABC):
model, optimizer = arg
except ValueError:
raise RuntimeError(f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"')
model, optimizer, *_ = self.booster.boost(model=model,
optimizer=optimizer)
model, optimizer, *_ = self.booster.boost(model=model, optimizer=optimizer)
rets.append((model, optimizer))
elif isinstance(arg, Dict):
model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
@ -90,10 +89,7 @@ class Strategy(ABC):
dataloader=dataloader,
lr_scheduler=lr_scheduler)
# remove None values
boost_result = {
key: value
for key, value in boost_result.items() if value is not None
}
boost_result = {key: value for key, value in boost_result.items() if value is not None}
rets.append(boost_result)
else:
raise RuntimeError(f'Type {type(arg)} is not supported')
@ -112,23 +108,13 @@ class Strategy(ABC):
"""
return model
def save_model(self,
model: nn.Module,
path: str,
only_rank0: bool = True,
**kwargs
) -> None:
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True, **kwargs) -> None:
self.booster.save_model(model, path, shard=not only_rank0, **kwargs)
def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
self.booster.load_model(model, path, strict)
def save_optimizer(self,
optimizer: Optimizer,
path: str,
only_rank0: bool = False,
**kwargs
) -> None:
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False, **kwargs) -> None:
self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs)
def load_optimizer(self, optimizer: Optimizer, path: str) -> None:

Loading…
Cancel
Save