@@ -4,7 +4,7 @@ from typing import Dict, List, Tuple, Union
 import torch
 from torch.fx.node import Node

-from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
+from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -234,15 +234,19 @@ class MetaInfoNodeHandler(NodeHandler):
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
-        metainfo_vector = []
-        for strategy in self.strategies_vector:
-            metainfo = MetaInfo(strategy, target)
-            strategy.compute_cost = metainfo.compute_cost
-            strategy.memory_cost = metainfo.memory_cost
-            metainfo_vector.append(metainfo)
-
-        # attach metainfos to the handler
-        setattr(self, "metainfo_vector", metainfo_vector)
+        # Currently we haven't patched all the torch functions and modules, so if the target
+        # is not patched, we will use the default cost model to compute the cost.
+        # TODO: patch all torch functions and modules to make it clean
+        if meta_register.has(target.__class__) or meta_register.has(target):
+            metainfo_vector = []
+            for strategy in self.strategies_vector:
+                metainfo = MetaInfo(strategy, target)
+                strategy.compute_cost = metainfo.compute_cost
+                strategy.memory_cost = metainfo.memory_cost
+                metainfo_vector.append(metainfo)
+
+            # attach metainfos to the handler
+            setattr(self, "metainfo_vector", metainfo_vector)

         return self.strategies_vector
@@ -281,14 +285,18 @@ class MetaInfoModuleHandler(ModuleHandler):
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
-        metainfo_vector = []
-        for strategy in self.strategies_vector:
-            metainfo = MetaInfo(strategy, target)
-            strategy.compute_cost = metainfo.compute_cost
-            strategy.memory_cost = metainfo.memory_cost
-            metainfo_vector.append(metainfo)
-
-        # attach metainfos to the handler
-        setattr(self, "metainfo_vector", metainfo_vector)
+        # Currently we haven't patched all the torch functions and modules, so if the target
+        # is not patched, we will use the default cost model to compute the cost.
+        # TODO: patch all torch functions and modules to make it clean
+        if meta_register.has(target.__class__) or meta_register.has(target):
+            metainfo_vector = []
+            for strategy in self.strategies_vector:
+                metainfo = MetaInfo(strategy, target)
+                strategy.compute_cost = metainfo.compute_cost
+                strategy.memory_cost = metainfo.memory_cost
+                metainfo_vector.append(metainfo)
+
+            # attach metainfos to the handler
+            setattr(self, "metainfo_vector", metainfo_vector)

         return self.strategies_vector
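
Note: both hunks apply the same pattern, namely only attach MetaInfo-derived costs when the target (module class or function) is found in meta_register, and otherwise leave the default cost model results from register_strategy() untouched. Below is a minimal, self-contained sketch of that guard pattern; _SketchRegistry and attach_metainfo are hypothetical stand-ins for illustration only, not the real colossalai.auto_parallel.meta_profiler classes.

import torch

class _SketchRegistry:
    """Simplified stand-in for meta_register: a set of patched targets."""

    def __init__(self):
        self._targets = set()

    def register(self, target):
        self._targets.add(target)

    def has(self, target):
        return target in self._targets


meta_register = _SketchRegistry()
meta_register.register(torch.nn.Linear)    # pretend nn.Linear has been patched

def attach_metainfo(target, strategies):
    # Mirror the guard in the diff: only overwrite costs when the target
    # (its class or the callable itself) is registered; unpatched targets
    # keep whatever the default cost model produced.
    if meta_register.has(target.__class__) or meta_register.has(target):
        for strategy in strategies:
            strategy["patched"] = True    # placeholder for MetaInfo-derived costs
    return strategies

print(attach_metainfo(torch.nn.Linear(4, 4), [{}]))    # -> [{'patched': True}]
print(attach_metainfo(torch.relu, [{}]))               # -> [{}] (falls back to defaults)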