diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 4b2874fdb..2bb83862e 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -1,6 +1,7 @@ from functools import partial from typing import Callable, Any, Dict, Tuple import torch +from torch.nn.parameter import Parameter from torch.fx import Graph, Node from torch.fx.node import Argument, Target from torch.utils._pytree import tree_map @@ -298,7 +299,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable: param_size = 0 def get_param_size(x): - if isinstance(x, torch.nn.parameter): + if isinstance(x, Parameter): param_size += activation_size(x) tree_map(get_param_size, args)