From d8420f81a437ad07aa0c3993d2a8ff9d1ede6d16 Mon Sep 17 00:00:00 2001 From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com> Date: Wed, 5 Oct 2022 21:59:05 +0800 Subject: [PATCH] [hotfix] fix wrong type name in profiler (#1678) --- colossalai/fx/profiler/profiler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/fx/profiler/profiler.py b/colossalai/fx/profiler/profiler.py index 4b2874fdb..2bb83862e 100644 --- a/colossalai/fx/profiler/profiler.py +++ b/colossalai/fx/profiler/profiler.py @@ -1,6 +1,7 @@ from functools import partial from typing import Callable, Any, Dict, Tuple import torch +from torch.nn.parameter import Parameter from torch.fx import Graph, Node from torch.fx.node import Argument, Target from torch.utils._pytree import tree_map @@ -298,7 +299,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable: param_size = 0 def get_param_size(x): - if isinstance(x, torch.nn.parameter): + if isinstance(x, Parameter): param_size += activation_size(x) tree_map(get_param_size, args)