Browse Source

[hotfix] fix wrong type name in profiler (#1678)

pull/1680/head
Boyuan Yao 2 years ago committed by GitHub
parent
commit
d8420f81a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      colossalai/fx/profiler/profiler.py

3
colossalai/fx/profiler/profiler.py

@ -1,6 +1,7 @@
from functools import partial from functools import partial
from typing import Callable, Any, Dict, Tuple from typing import Callable, Any, Dict, Tuple
import torch import torch
from torch.nn.parameter import Parameter
from torch.fx import Graph, Node from torch.fx import Graph, Node
from torch.fx.node import Argument, Target from torch.fx.node import Argument, Target
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
@ -298,7 +299,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
param_size = 0 param_size = 0
def get_param_size(x): def get_param_size(x):
if isinstance(x, torch.nn.parameter): if isinstance(x, Parameter):
param_size += activation_size(x) param_size += activation_size(x)
tree_map(get_param_size, args) tree_map(get_param_size, args)

Loading…
Cancel
Save