|
|
@ -1,6 +1,7 @@ |
|
|
|
from functools import partial |
|
|
|
from functools import partial |
|
|
|
from typing import Callable, Any, Dict, Tuple |
|
|
|
from typing import Callable, Any, Dict, Tuple |
|
|
|
import torch |
|
|
|
import torch |
|
|
|
|
|
|
|
from torch.nn.parameter import Parameter |
|
|
|
from torch.fx import Graph, Node |
|
|
|
from torch.fx import Graph, Node |
|
|
|
from torch.fx.node import Argument, Target |
|
|
|
from torch.fx.node import Argument, Target |
|
|
|
from torch.utils._pytree import tree_map |
|
|
|
from torch.utils._pytree import tree_map |
|
|
@ -298,7 +299,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable: |
|
|
|
param_size = 0 |
|
|
|
param_size = 0 |
|
|
|
|
|
|
|
|
|
|
|
def get_param_size(x): |
|
|
|
def get_param_size(x): |
|
|
|
if isinstance(x, torch.nn.parameter): |
|
|
|
if isinstance(x, Parameter): |
|
|
|
param_size += activation_size(x) |
|
|
|
param_size += activation_size(x) |
|
|
|
|
|
|
|
|
|
|
|
tree_map(get_param_size, args) |
|
|
|
tree_map(get_param_size, args) |
|
|
|