Making large AI models cheaper, faster and more accessible
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

25 lines
645 B

import torch.nn as nn
class ModelWrapper(nn.Module):
    """
    A wrapper class to define the common interface used by booster.

    Args:
        module (nn.Module): The model to be wrapped.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule of this wrapper.
        self.module = module

    def unwrap(self) -> nn.Module:
        """
        Unwrap the model to return the original model for checkpoint saving/loading.

        Returns:
            nn.Module: The wrapped module. If the wrapped module is itself a
            ``ModelWrapper``, the result of its own ``unwrap()`` is returned instead.
        """
        # Delegate through the inner wrapper's ``unwrap`` (rather than reading
        # ``.module`` in a loop) so that an override on the inner wrapper is honored.
        if isinstance(self.module, ModelWrapper):
            return self.module.unwrap()
        return self.module

    def forward(self, *args, **kwargs):
        """
        Call the wrapped module with the given arguments.

        All positional and keyword arguments are passed through unchanged, and
        the wrapped module's return value is returned as-is.
        """
        return self.module(*args, **kwargs)