Browse Source

[hotfix] fix wrong type name in profiler (#1678)

pull/1680/head
Boyuan Yao 2 years ago committed by GitHub
parent
commit
d8420f81a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      colossalai/fx/profiler/profiler.py

3
colossalai/fx/profiler/profiler.py

@ -1,6 +1,7 @@
from functools import partial
from typing import Callable, Any, Dict, Tuple
import torch
from torch.nn.parameter import Parameter
from torch.fx import Graph, Node
from torch.fx.node import Argument, Target
from torch.utils._pytree import tree_map
@ -298,7 +299,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
param_size = 0
def get_param_size(x):
if isinstance(x, torch.nn.parameter):
if isinstance(x, Parameter):
param_size += activation_size(x)
tree_map(get_param_size, args)

Loading…
Cancel
Save