mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix wrong type name in profiler (#1678)
parent
132b4306b7
commit
d8420f81a4
|
@ -1,6 +1,7 @@
|
|||
from functools import partial
|
||||
from typing import Callable, Any, Dict, Tuple
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.fx import Graph, Node
|
||||
from torch.fx.node import Argument, Target
|
||||
from torch.utils._pytree import tree_map
|
||||
|
@ -298,7 +299,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
|
|||
param_size = 0
|
||||
|
||||
def get_param_size(x):
|
||||
if isinstance(x, torch.nn.parameter):
|
||||
if isinstance(x, Parameter):
|
||||
param_size += activation_size(x)
|
||||
|
||||
tree_map(get_param_size, args)
|
||||
|
|
Loading…
Reference in New Issue