mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix wrong type name in profiler (#1678)
parent
132b4306b7
commit
d8420f81a4
|
@ -1,6 +1,7 @@
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Callable, Any, Dict, Tuple
|
from typing import Callable, Any, Dict, Tuple
|
||||||
import torch
|
import torch
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
from torch.fx import Graph, Node
|
from torch.fx import Graph, Node
|
||||||
from torch.fx.node import Argument, Target
|
from torch.fx.node import Argument, Target
|
||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_map
|
||||||
|
@ -298,7 +299,7 @@ def profile_function(target: 'Target', device: str = 'meta') -> Callable:
|
||||||
param_size = 0
|
param_size = 0
|
||||||
|
|
||||||
def get_param_size(x):
|
def get_param_size(x):
|
||||||
if isinstance(x, torch.nn.parameter):
|
if isinstance(x, Parameter):
|
||||||
param_size += activation_size(x)
|
param_size += activation_size(x)
|
||||||
|
|
||||||
tree_map(get_param_size, args)
|
tree_map(get_param_size, args)
|
||||||
|
|
Loading…
Reference in New Issue