@ -24,26 +24,25 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
Tuple [ TrainCycleItem , TrainCycleItem , List [ torch . Tensor ] ] : compute cost , memory cost and forward inputs
Tuple [ TrainCycleItem , TrainCycleItem , List [ torch . Tensor ] ] : compute cost , memory cost and forward inputs
"""
"""
input_op_data , other_op_data = [ arg for arg in args if arg . type != OperationDataType . OUTPUT ]
input_op_data = [ arg for arg in args if arg . type != OperationDataType . OUTPUT ]
output_op_data = next ( filter ( lambda arg : arg . type == OperationDataType . OUTPUT , args ) )
output_op_data = next ( filter ( lambda arg : arg . type == OperationDataType . OUTPUT , args ) )
# construct forward args for flop mapping
# construct forward args for flop mapping
fwd_in_args = [ input_op_data. data , other_op_data . data]
fwd_in_args = [ opdata. data for opdata in input_op_ data]
fwd_out_args = [ output_op_data . data ]
fwd_out_args = [ output_op_data . data ]
# calculate cost
# calculate cost
# calculate compute cost
# calculate compute cost
# NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case
# NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case
fwd_compute_cost = flop_mapping [ torch . ops . aten . _adaptive_avg_pool2d. default ] ( fwd_in_args , fwd_out_args )
fwd_compute_cost = flop_mapping [ torch . ops . aten . add. Tensor ] ( fwd_in_args , fwd_out_args )
bwd_compute_cost = fwd_compute_cost * 2
bwd_compute_cost = fwd_compute_cost * 2
compute_cost = TrainCycleItem ( fwd = fwd_compute_cost , bwd = bwd_compute_cost , total = fwd_compute_cost + bwd_compute_cost )
compute_cost = TrainCycleItem ( fwd = fwd_compute_cost , bwd = bwd_compute_cost , total = fwd_compute_cost + bwd_compute_cost )
# calculate memory cost
# calculate memory cost
param_mem_cost = activation_size (
param_mem_cost = activation_size ( [ arg . data for arg in input_op_data if arg . type == OperationDataType . PARAM ] )
[ arg . data for arg in [ input_op_data , other_op_data ] if arg . type == OperationDataType . PARAM ] )
fwd_mem_cost = MemoryCost (
fwd_mem_cost = MemoryCost (
activation = activation_size ( [ input_op_data . data , output_op_data . data ] ) ,
activation = activation_size ( output_op_data . data ) ,
parameter = param_mem_cost ,
parameter = param_mem_cost ,
)
)
bwd_mem_cost = MemoryCost (
bwd_mem_cost = MemoryCost (