from .utils import InsertPostInitMethodToModuleSubClasses
import torch

from colossalai.tensor import ColoTensor, ColoParameter, register_colo_module, init_colo_module, \
    ColoLinear
import types

from torch import nn
from typing import Iterator, Tuple, Union, Optional


# Find named parameters, including replicated ones.
def _named_params_with_replica(
    module: nn.Module,
    prefix: str = '',
    recurse: bool = True,
) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
    modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]

    for mod_prefix, mod in modules:
        for name, val in mod._parameters.items():
            if val is None:
                continue
            name = mod_prefix + ('.' if mod_prefix else '') + name
            yield name, val


# Adapted from torch.nn.Module.register_parameter
def _register_parameter_with_colotensor(self, name: str, param):
    if '_parameters' not in self.__dict__:
        raise AttributeError("cannot assign parameter before Module.__init__() call")

    if not isinstance(name, str):
        raise TypeError("parameter name should be a string. "
                        "Got {}".format(torch.typename(name)))
    if '.' in name:
        raise KeyError("parameter name can't contain \".\"")
    if name == '':
        raise KeyError("parameter name can't be empty string \"\"")
    if hasattr(self, name) and name not in self._parameters:
        raise KeyError("attribute '{}' already exists".format(name))

    if param is None:
        self._parameters[name] = None
    elif not isinstance(param, (torch.nn.Parameter, ColoParameter)):
        raise TypeError("cannot assign '{}' object to parameter '{}' "
                        "(torch.nn.Parameter or ColoParameter or None required)".format(torch.typename(param), name))
    elif param.grad_fn:
        raise ValueError("Cannot assign non-leaf Tensor to parameter '{0}'. Model "
                         "parameters must be created explicitly. To express '{0}' "
                         "as a function of another Tensor, compute the value in "
                         "the forward() method.".format(name))
    else:
        self._parameters[name] = param


# Adapted from torch.nn.Module.__setattr__
def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.nn.Module, ColoTensor]):

    def remove_from(*dicts_or_sets):
        for d in dicts_or_sets:
            if name in d:
                if isinstance(d, dict):
                    del d[name]
                else:
                    d.discard(name)

    params = self.__dict__.get('_parameters')
    if isinstance(value, (ColoParameter, torch.nn.Parameter)):
        if params is None:
            raise AttributeError("cannot assign parameters before Module.__init__() call")
        remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
        self.register_parameter(name, value)
    elif params is not None and name in params:
        if value is not None:
            raise TypeError("cannot assign '{}' as parameter '{}' "
                            "(torch.nn.Parameter or None expected)".format(torch.typename(value), name))
        self.register_parameter(name, value)
    else:
        modules = self.__dict__.get('_modules')
        if isinstance(value, torch.nn.Module):
            if modules is None:
                raise AttributeError("cannot assign module before Module.__init__() call")
            remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
            modules[name] = value
        elif modules is not None and name in modules:
            if value is not None:
                raise TypeError("cannot assign '{}' as child module '{}' "
                                "(torch.nn.Module or None expected)".format(torch.typename(value), name))
            modules[name] = value
        else:
            buffers = self.__dict__.get('_buffers')
            if buffers is not None and name in buffers:
                if value is not None and not isinstance(value, torch.Tensor):
                    raise TypeError("cannot assign '{}' as buffer '{}' "
                                    "(torch.Tensor or None expected)".format(torch.typename(value), name))
                buffers[name] = value
            else:
                object.__setattr__(self, name, value)


def _get_parameter_with_colotensor(self, target: str) -> Union[torch.nn.Parameter, ColoTensor]:
    module_path, _, param_name = target.rpartition(".")

    mod: torch.nn.Module = self.get_submodule(module_path)

    if not hasattr(mod, param_name):
        raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`")

    param = getattr(mod, param_name)

    return param


def ColoModulize(module):
    """
    Mark `module` as visited so that `_post_init_method` skips it on later calls.
    """
    module._colo_visited = True


class ColoInitContext(InsertPostInitMethodToModuleSubClasses):

    def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
        """
        Args:
            lazy_memory_allocate (bool, optional): whether to allocate memory for the parameter tensors.
                Defaults to False.
            device (torch.device, optional): the device on which the initialized parameters reside.
                Defaults to torch.device('cpu').
        """
        super().__init__()
        self._lazy_memory_allocate = lazy_memory_allocate
        self._device = device

        # Patch nn.Module globally so that ColoParameter is accepted wherever
        # torch.nn.Parameter is expected.
        torch.nn.Module.__setattr__ = _setattr_with_colotensor
        torch.nn.Module.register_parameter = _register_parameter_with_colotensor
        torch.nn.Module.get_parameter = _get_parameter_with_colotensor
        register_colo_module(torch.nn.Linear, ColoLinear())

    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
        """
        The function to call at the end of the constructor of each module.
        FIXME(fjr) The module may be passed to this function multiple times?
        """
        if hasattr(module, '_colo_visited'):
            return

        name_list = []
        for name, param in _named_params_with_replica(module):
            if isinstance(param, ColoTensor):
                continue

            split = name.rfind('.')
            if split >= 0:    # param in submodule
                module_name = name[:split]
                param_name = name[split + 1:]
            else:
                module_name = ''    # param in current module
                param_name = name
            name_list.append((module_name, param_name))

        # Map each original torch.Tensor to its ColoTensor so that parameters referring
        # to the same tensor are replaced by the same ColoParameter.
        replaced_tensors = dict()
        for module_name, param_name in name_list:
            submodule = module.get_submodule(module_name)
            param = submodule.get_parameter(param_name)
            if param in replaced_tensors:
                colo_param = replaced_tensors[param]
            else:
                save_torch_payload = not self._lazy_memory_allocate
                # detaching tensor is necessary for optimizers.
                requires_grad = param.requires_grad
                colo_param = ColoParameter(param.to(self._device), requires_grad=requires_grad)
                # add mapping record
                replaced_tensors[param] = colo_param
            delattr(submodule, param_name)
            setattr(submodule, param_name, colo_param)

        ColoModulize(module)
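
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's public surface):
# a minimal example of how ColoInitContext is typically entered so that
# parameters created inside the `with` block become ColoParameter instances
# on the requested device. `Net` and its layer sizes are hypothetical names
# used only for this sketch.
#
#   class Net(torch.nn.Module):
#
#       def __init__(self):
#           super().__init__()
#           self.proj = torch.nn.Linear(16, 16)
#
#   with ColoInitContext(device=torch.device('cpu')):
#       model = Net()
#
#   # Every parameter of `model` is expected to be a ColoParameter now.
#   print(all(isinstance(p, ColoParameter) for p in model.parameters()))
# ---------------------------------------------------------------------------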