mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] bypass MetaInfo when unavailable and modify BCAST_FUNC_OP metainfo (#2293)
* [autoparallel] align the data_ptr with the old version of auto activation checkpoint pipeline * [autoparallel] using fwd_time and bwd_time instead of fwd_flop and bwd_flop * [autoparallel] specifycomm nodes' memory cost in construct chain * [autoparallel] fix wrong runtime apply calculation * [autoparallel] fix wrong runtime apply calculation * [autoparallel] fix wrong runtime apply calculation * [autoparallel] bypass metainfo when available and modify BCAST_FUNC_OPpull/2258/head
parent
8ea50d999e
commit
b904748210
|
@ -24,26 +24,25 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
|
|||
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
|
||||
"""
|
||||
|
||||
input_op_data, other_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
|
||||
input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
|
||||
output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))
|
||||
|
||||
# construct forward args for flop mapping
|
||||
fwd_in_args = [input_op_data.data, other_op_data.data]
|
||||
fwd_in_args = [opdata.data for opdata in input_op_data]
|
||||
fwd_out_args = [output_op_data.data]
|
||||
|
||||
# calculate cost
|
||||
|
||||
# calculate compute cost
|
||||
# NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args)
|
||||
fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args)
|
||||
bwd_compute_cost = fwd_compute_cost * 2
|
||||
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
|
||||
|
||||
# calculate memory cost
|
||||
param_mem_cost = activation_size(
|
||||
[arg.data for arg in [input_op_data, other_op_data] if arg.type == OperationDataType.PARAM])
|
||||
param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM])
|
||||
fwd_mem_cost = MemoryCost(
|
||||
activation=activation_size([input_op_data.data, output_op_data.data]),
|
||||
activation=activation_size(output_op_data.data),
|
||||
parameter=param_mem_cost,
|
||||
)
|
||||
bwd_mem_cost = MemoryCost(
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import Dict, List, Tuple, Union
|
|||
import torch
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
|
||||
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
OperationData,
|
||||
OperationDataType,
|
||||
|
@ -234,6 +234,10 @@ class MetaInfoNodeHandler(NodeHandler):
|
|||
"""
|
||||
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
|
||||
target = self.get_target_function()
|
||||
# Currently we haven't patched all the torch functions and modules, so if the target
|
||||
# is not patched, we will use the default cost model to compute the cost.
|
||||
# TODO: patch all torch functions and modules to make it clean
|
||||
if meta_register.has(target.__class__) or meta_register.has(target):
|
||||
metainfo_vector = []
|
||||
for strategy in self.strategies_vector:
|
||||
metainfo = MetaInfo(strategy, target)
|
||||
|
@ -281,6 +285,10 @@ class MetaInfoModuleHandler(ModuleHandler):
|
|||
"""
|
||||
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
|
||||
target = self.get_target_function()
|
||||
# Currently we haven't patched all the torch functions and modules, so if the target
|
||||
# is not patched, we will use the default cost model to compute the cost.
|
||||
# TODO: patch all torch functions and modules to make it clean
|
||||
if meta_register.has(target.__class__) or meta_register.has(target):
|
||||
metainfo_vector = []
|
||||
for strategy in self.strategies_vector:
|
||||
metainfo = MetaInfo(strategy, target)
|
||||
|
|
Loading…
Reference in New Issue