import os
import warnings
import torch
import torch.nn as nn
from torch.nn.modules.module import _addindent
from typing import Type, Dict, List, Any, Union, Optional, Set
from pathlib import Path

try:
    from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src
    from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
    COLOGM = True
except ImportError:
    # Some of the private torch.fx helpers above are unavailable in this torch build;
    # fall back to a plain GraphModule subclass without checkpoint-function support.
    from torch.fx.graph_module import GraphModule
    from torch.fx.graph import Graph
    COLOGM = False

if COLOGM:

    class ColoGraphModule(GraphModule):
        """``GraphModule`` that can host checkpoint functions generated alongside ``forward``.

        ``recompile`` treats every line of the generated source that precedes ``def forward``
        as checkpoint-function definitions. It executes them and binds the resulting functions
        as methods on the instance (see ``bind``), so the generated ``forward`` can call them
        through ``self``.
        """

        def __init__(self,
                     root: Union[torch.nn.Module, Dict[str, Any]],
                     graph: Graph,
                     class_name: str = 'GraphModule'):
            super().__init__(root, graph, class_name)

        def bind(self, ckpt_def, globals):
            """Bind checkpoint functions to ColoGraphModule.

            We need to bind our checkpoint functions to the GraphModule so that we could
            correctly use self.checkpoint for GraphModule forward.

            Args:
                ckpt_def (List[str]): source lines defining the checkpoint functions.
                globals (Dict[str, Any]): namespace the generated code runs in. Only a copy
                    of it is used for executing ``ckpt_def``; the dict itself is not mutated.
            """
            ckpt_code = "\n".join(ckpt_def)
            if not ckpt_code.strip():
                # Nothing to execute (e.g. only blank lines precede ``def forward``).
                return
            globals_copy = globals.copy()
            _exec_with_source(ckpt_code, globals_copy)
            # Only bind names that the checkpoint source itself defined (or rebound). Entries
            # already present in ``globals`` are not checkpoint functions even if their name
            # happens to contain "checkpoint", and they may not support ``__get__`` at all.
            func_list = [
                name for name, value in globals_copy.items()
                if "checkpoint" in name and (name not in globals or globals[name] is not value)
            ]
            for func in func_list:
                setattr(self, func, globals_copy[func].__get__(self, self.__class__))
                del globals_copy[func]

        def recompile(self) -> PythonCode:
            """
            Recompile this GraphModule from its ``graph`` attribute. This should be
            called after editing the contained ``graph``, otherwise the generated
            code of this ``GraphModule`` will be out of date.
            """
            if isinstance(self._graph._codegen, _PyTreeCodeGen):
                self._in_spec = self._graph._codegen.pytree_info.in_spec
                self._out_spec = self._graph._codegen.pytree_info.out_spec
            python_code = self._graph.python_code(root_module='self')
            self._code = python_code.src

            # Split the generated source: every line before the first ``def forward`` line is
            # checkpoint-function code (bound as methods via ``bind``); the rest is ``forward``.
            code_lines = self._code.split("\n")
            fwd_idx = [idx for idx, line in enumerate(code_lines) if "def forward" in line][0]
            ckpt_def = code_lines[:fwd_idx]
            self._code = "\n".join(code_lines[fwd_idx:])
            self.bind(ckpt_def, python_code.globals)

            cls = type(self)
            cls.forward = _forward_from_src(self._code, python_code.globals)

            # Determine whether this class explicitly defines a __call__ implementation
            # to wrap. If it does, save it in order to have wrapped_call invoke it.
            # If it does not, wrapped_call can use a dynamic call to super() instead.
            # In most cases, super().__call__ should be torch.nn.Module.__call__.
            # We do not want to hold a reference to Module.__call__ here; doing so will
            # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
            cls_call = cls.__call__ if "__call__" in vars(cls) else None

            if '_wrapped_call' not in vars(cls):
                cls._wrapped_call = _WrappedCall(cls, cls_call)    # type: ignore[attr-defined]

            def call_wrapped(self, *args, **kwargs):
                return self._wrapped_call(self, *args, **kwargs)

            cls.__call__ = call_wrapped

            # reset self._code to original src (checkpoint functions included),
            # otherwise to_folder will be wrong
            self._code = python_code.src
            return python_code

        def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
            """Dumps out module to ``folder`` with ``module_name`` so that it can be
            imported with ``from <folder> import <module_name>``

            Args:
                folder (Union[str, os.PathLike]): The folder to write the code out to
                module_name (str): Top-level name to use for the ``Module`` while
                    writing out the code
            """
            folder = Path(folder)
            folder.mkdir(exist_ok=True)
            torch.save(self.state_dict(), folder / 'state_dict.pt')
            tab = " " * 4

            # we add import colossalai here
            model_str = f"""
import torch
from torch.nn import *
import colossalai
class {module_name}(torch.nn.Module):
    def __init__(self):
        super().__init__()
"""

            def _gen_model_repr(module: torch.nn.Module) -> Optional[str]:
                # Only these exact module types are emitted via their repr as constructor code;
                # anything else returns None and gets pickled instead.
                safe_reprs = [
                    nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
                ]
                if type(module) in safe_reprs:
                    return f"{module.__repr__()}"
                return None

            blobified_modules = []
            # Distinct loop names so the ``module_name`` argument is not shadowed.
            for child_name, child in self.named_children():
                child_str = _gen_model_repr(child)
                if child_str is None:
                    child_file = folder / f'{child_name}.pt'
                    torch.save(child, child_file)
                    blobified_modules.append(child_name)
                    child_repr = child.__repr__().replace('\r', ' ').replace('\n', ' ')
                    child_str = f"torch.load(r'{child_file}') # {child_repr}"
                model_str += f"{tab*2}self.{child_name} = {child_str}\n"

            for buffer_name, buffer in self._buffers.items():
                if buffer is None:
                    continue
                model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"

            for param_name, param in self._parameters.items():
                if param is None:
                    continue
                model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"

            model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
            # self.code holds the full source (checkpoint functions + forward); indenting it by
            # one level turns every top-level def into a method of the emitted class.
            model_str += f"{_addindent(self.code, 4)}\n"

            module_file = folder / 'module.py'
            module_file.write_text(model_str)

            init_file = folder / '__init__.py'
            init_file.write_text('from .module import *')

            if len(blobified_modules) > 0:
                warnings.warn("Was not able to save the following children modules as reprs - "
                              f"saved as pickled files instead: {blobified_modules}")

else:

    class ColoGraphModule(GraphModule):
        """Fallback used when the private torch.fx helpers could not be imported."""

        def __init__(self,
                     root: Union[torch.nn.Module, Dict[str, Any]],
                     graph: Graph,
                     class_name: str = 'GraphModule'):
            super().__init__(root, graph, class_name)