mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish applications/Chat/coati/trainer/strategies/base.py code style (#4278)
parent
c972d65311
commit
9e512938f6
|
@ -79,8 +79,7 @@ class Strategy(ABC):
|
|||
model, optimizer = arg
|
||||
except ValueError:
|
||||
raise RuntimeError(f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"')
|
||||
model, optimizer, *_ = self.booster.boost(model=model,
|
||||
optimizer=optimizer)
|
||||
model, optimizer, *_ = self.booster.boost(model=model, optimizer=optimizer)
|
||||
rets.append((model, optimizer))
|
||||
elif isinstance(arg, Dict):
|
||||
model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
|
||||
|
@ -90,10 +89,7 @@ class Strategy(ABC):
|
|||
dataloader=dataloader,
|
||||
lr_scheduler=lr_scheduler)
|
||||
# remove None values
|
||||
boost_result = {
|
||||
key: value
|
||||
for key, value in boost_result.items() if value is not None
|
||||
}
|
||||
boost_result = {key: value for key, value in boost_result.items() if value is not None}
|
||||
rets.append(boost_result)
|
||||
else:
|
||||
raise RuntimeError(f'Type {type(arg)} is not supported')
|
||||
|
@ -112,23 +108,13 @@ class Strategy(ABC):
|
|||
"""
|
||||
return model
|
||||
|
||||
def save_model(self,
|
||||
model: nn.Module,
|
||||
path: str,
|
||||
only_rank0: bool = True,
|
||||
**kwargs
|
||||
) -> None:
|
||||
def save_model(self, model: nn.Module, path: str, only_rank0: bool = True, **kwargs) -> None:
|
||||
self.booster.save_model(model, path, shard=not only_rank0, **kwargs)
|
||||
|
||||
def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
|
||||
self.booster.load_model(model, path, strict)
|
||||
|
||||
def save_optimizer(self,
|
||||
optimizer: Optimizer,
|
||||
path: str,
|
||||
only_rank0: bool = False,
|
||||
**kwargs
|
||||
) -> None:
|
||||
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False, **kwargs) -> None:
|
||||
self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs)
|
||||
|
||||
def load_optimizer(self, optimizer: Optimizer, path: str) -> None:
|
||||
|
|
Loading…
Reference in New Issue