[hotfix] fix wrong type name in profiler (#1678)

pull/1680/head
Boyuan Yao 2022-10-05 21:59:05 +08:00 committed by GitHub
parent 132b4306b7
commit d8420f81a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 1 deletions

View File

@ -1,6 +1,7 @@
from functools import partial
from typing import Callable, Any, Dict, Tuple
import torch
from torch.nn.parameter import Parameter
from torch.fx import Graph, Node
from torch.fx.node import Argument, Target
from torch.utils._pytree import tree_map
@ -298,7 +299,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
param_size = 0
def get_param_size(x):
if isinstance(x, torch.nn.parameter):
if isinstance(x, Parameter):
param_size += activation_size(x)
tree_map(get_param_size, args)