@@ -79,8 +79,7 @@ class Strategy(ABC):
                     model, optimizer = arg
                 except ValueError:
                     raise RuntimeError(f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"')
-                model, optimizer, *_ = self.booster.boost(model=model,
-                                                          optimizer=optimizer)
+                model, optimizer, *_ = self.booster.boost(model=model, optimizer=optimizer)
                 rets.append((model, optimizer))
             elif isinstance(arg, Dict):
                 model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
@@ -90,10 +89,7 @@ class Strategy(ABC):
                                     dataloader=dataloader,
                                     lr_scheduler=lr_scheduler)
                 # remove None values
-                boost_result = {
-                    key: value
-                    for key, value in boost_result.items() if value is not None
-                }
+                boost_result = {key: value for key, value in boost_result.items() if value is not None}
                 rets.append(boost_result)
             else:
                 raise RuntimeError(f'Type {type(arg)} is not supported')
@@ -112,23 +108,13 @@ class Strategy(ABC):
         """
         return model

-    def save_model(self,
-                   model: nn.Module,
-                   path: str,
-                   only_rank0: bool = True,
-                   **kwargs
-                   ) -> None:
+    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True, **kwargs) -> None:
         self.booster.save_model(model, path, shard=not only_rank0, **kwargs)

     def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
         self.booster.load_model(model, path, strict)

-    def save_optimizer(self,
-                       optimizer: Optimizer,
-                       path: str,
-                       only_rank0: bool = False,
-                       **kwargs
-                       ) -> None:
+    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False, **kwargs) -> None:
         self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs)

     def load_optimizer(self, optimizer: Optimizer, path: str) -> None:
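
For context, a minimal usage sketch of the touched API, assuming a concrete Strategy subclass. The NaiveStrategy import path, the checkpoint file names, the toy model/optimizer, and the single-pair return behavior of prepare() are illustrative assumptions, not taken from this diff.

import torch.nn as nn
from torch.optim import Adam

# Assumed concrete subclass of the abstract Strategy; substitute whichever
# strategy class the project actually provides.
from coati.trainer.strategies import NaiveStrategy

strategy = NaiveStrategy()
model = nn.Linear(8, 2)
optimizer = Adam(model.parameters(), lr=1e-3)

# prepare() boosts a (model, optimizer) tuple via self.booster.boost(...);
# with a single argument it is assumed to return the boosted pair directly.
model, optimizer = strategy.prepare((model, optimizer))

# save_model/save_optimizer delegate to the booster with shard=not only_rank0:
# only_rank0=True writes one unsharded checkpoint, only_rank0=False writes shards.
strategy.save_model(model, 'model_ckpt.pt', only_rank0=True)
strategy.save_optimizer(optimizer, 'optim_ckpt.pt', only_rank0=False)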