[hotfix] fix wrong type name in profiler (#1678)

pull/1680/head
Boyuan Yao 2022-10-05 21:59:05 +08:00 committed by GitHub
parent 132b4306b7
commit d8420f81a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 1 deletions

View File

@ -1,6 +1,7 @@
from functools import partial from functools import partial
from typing import Callable, Any, Dict, Tuple from typing import Callable, Any, Dict, Tuple
import torch import torch
from torch.nn.parameter import Parameter
from torch.fx import Graph, Node from torch.fx import Graph, Node
from torch.fx.node import Argument, Target from torch.fx.node import Argument, Target
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
@ -298,7 +299,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
param_size = 0 param_size = 0
def get_param_size(x): def get_param_size(x):
if isinstance(x, torch.nn.parameter): if isinstance(x, Parameter):
param_size += activation_size(x) param_size += activation_size(x)
tree_map(get_param_size, args) tree_map(get_param_size, args)