[autoparallel] bypass MetaInfo when unavailable and modify BCAST_FUNC_OP metainfo (#2293)

* [autoparallel] align the data_ptr with the old version of auto activation checkpoint pipeline

* [autoparallel] using fwd_time and bwd_time instead of fwd_flop and bwd_flop

* [autoparallel] specify comm nodes' memory cost in construct chain

* [autoparallel] fix wrong runtime apply calculation

* [autoparallel] fix wrong runtime apply calculation

* [autoparallel] fix wrong runtime apply calculation

* [autoparallel] bypass metainfo when unavailable and modify BCAST_FUNC_OP
Boyuan Yao 2 years ago committed by GitHub
parent 8ea50d999e
commit b904748210

@@ -24,26 +24,25 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """
-    input_op_data, other_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
+    input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
     output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))
 
     # construct forward args for flop mapping
-    fwd_in_args = [input_op_data.data, other_op_data.data]
+    fwd_in_args = [opdata.data for opdata in input_op_data]
     fwd_out_args = [output_op_data.data]
 
     # calculate cost
 
     # calculate compute cost
     # NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case
-    fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args)
+    fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args)
     bwd_compute_cost = fwd_compute_cost * 2
     compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
 
     # calculate memory cost
-    param_mem_cost = activation_size(
-        [arg.data for arg in [input_op_data, other_op_data] if arg.type == OperationDataType.PARAM])
+    param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM])
     fwd_mem_cost = MemoryCost(
-        activation=activation_size([input_op_data.data, output_op_data.data]),
+        activation=activation_size(output_op_data.data),
         parameter=param_mem_cost,
     )
     bwd_mem_cost = MemoryCost(
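
For reference, here is a minimal pure-PyTorch sketch of the generalized operand handling above: every non-OUTPUT operand is gathered into a list, the elementwise flop count is taken over the broadcast output, and only the output tensor is charged as forward activation memory. The helpers elementwise_flops and bytes_of are illustrative stand-ins for colossalai's flop_mapping[torch.ops.aten.add.Tensor] and activation_size, not the library API.

import torch

def elementwise_flops(inputs, outputs):
    # stand-in for flop_mapping[torch.ops.aten.add.Tensor]: one FLOP per output element
    return sum(out.numel() for out in outputs)

def bytes_of(tensors):
    # stand-in for activation_size: total bytes of a tensor or a list of tensors
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    return sum(t.numel() * t.element_size() for t in tensors)

# two broadcast operands of a BCAST_FUNC_OP such as torch.add
a = torch.randn(4, 1, 8)
b = torch.randn(1, 6, 8)
out = a + b    # broadcast result, shape (4, 6, 8)

fwd_in_args = [a, b]    # all non-OUTPUT operands, however many there are
fwd_out_args = [out]

fwd_compute_cost = elementwise_flops(fwd_in_args, fwd_out_args)    # 192
bwd_compute_cost = fwd_compute_cost * 2                            # mirrors the 2x rule above
fwd_activation_bytes = bytes_of(out)                               # 192 * 4 = 768 for float32

print(fwd_compute_cost, bwd_compute_cost, fwd_activation_bytes)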

@@ -4,7 +4,7 @@ from typing import Dict, List, Tuple, Union
 import torch
 from torch.fx.node import Node
 
-from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
+from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -234,6 +234,10 @@ class MetaInfoNodeHandler(NodeHandler):
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
+        # Currently we haven't patched all the torch functions and modules, so if the target
+        # is not patched, we will use the default cost model to compute the cost.
+        # TODO: patch all torch functions and modules to make it clean
+        if meta_register.has(target.__class__) or meta_register.has(target):
         metainfo_vector = []
         for strategy in self.strategies_vector:
             metainfo = MetaInfo(strategy, target)
@@ -281,6 +285,10 @@ class MetaInfoModuleHandler(ModuleHandler):
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
+        # Currently we haven't patched all the torch functions and modules, so if the target
+        # is not patched, we will use the default cost model to compute the cost.
+        # TODO: patch all torch functions and modules to make it clean
+        if meta_register.has(target.__class__) or meta_register.has(target):
         metainfo_vector = []
         for strategy in self.strategies_vector:
             metainfo = MetaInfo(strategy, target)
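
For reference, here is a minimal sketch of the bypass pattern introduced in the two hunks above, using a toy registry in place of colossalai's meta_register; the names ToyRegistry, toy_register and estimate_cost are illustrative only. When neither the target nor its class is registered, the handler keeps the default cost model already attached to each sharding strategy instead of building a MetaInfo.

import torch

class ToyRegistry:
    """Illustrative stand-in for meta_register: maps targets to meta-profiling functions."""

    def __init__(self):
        self._store = {}

    def register(self, target, fn):
        self._store[target] = fn

    def has(self, target):
        return target in self._store

    def get(self, target):
        return self._store[target]

toy_register = ToyRegistry()
toy_register.register(torch.nn.Linear, lambda mod: "meta-profiled cost")

def estimate_cost(target):
    # Mirror the handler logic: only consult the meta profiler when the target
    # (or its class, for module instances) has been patched; otherwise fall back
    # to the default cost model.
    if toy_register.has(target.__class__) or toy_register.has(target):
        key = target if toy_register.has(target) else target.__class__
        return toy_register.get(key)(target)
    return "default cost model"

print(estimate_cost(torch.nn.Linear(4, 4)))    # meta-profiled cost
print(estimate_cost(torch.nn.ReLU()))          # default cost model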
