@ -3,8 +3,9 @@ from typing import Any, List, Optional, Tuple
import torch
import torch
import torch . cuda
import torch . cuda
from packaging . version import Version
from torch . nn import Module
from torch . nn import Module
from torch . utils . _pytree import SUPPORTED_NODES , TreeSpec , _register_pytree_node, tree_flatten, tree_map , tree_unflatten
from torch . utils . _pytree import SUPPORTED_NODES , TreeSpec , tree_flatten, tree_map , tree_unflatten
# this register are for torch under version 1.13.1, maybe removed in the future
# this register are for torch under version 1.13.1, maybe removed in the future
@ -16,7 +17,12 @@ def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]"
return OrderedDict ( ( key , value ) for key , value in zip ( context , values ) )
return OrderedDict ( ( key , value ) for key , value in zip ( context , values ) )
_register_pytree_node ( OrderedDict , _odict_flatten , _odict_unflatten )
if Version ( torch . __version__ ) < = Version ( " 1.13.1 " ) :
try :
from torch . utils . _pytree import register_pytree_node as _register_pytree_node
except ImportError :
from torch . utils . _pytree import _register_pytree_node
_register_pytree_node ( OrderedDict , _odict_flatten , _odict_unflatten )
def tree_map_hf ( fn : Any , pytree : Any ) :
def tree_map_hf ( fn : Any , pytree : Any ) :