[autoparallel] bypass MetaInfo when unavailable and modify BCAST_FUNC_OP metainfo (#2293)

* [autoparallel] align the data_ptr with the old version of auto activation checkpoint pipeline

* [autoparallel] using fwd_time and bwd_time instead of fwd_flop and bwd_flop

* [autoparallel] specify comm nodes' memory cost in construct chain

* [autoparallel] fix wrong runtime apply calculation

* [autoparallel] fix wrong runtime apply calculation

* [autoparallel] fix wrong runtime apply calculation

* [autoparallel] bypass metainfo when unavailable and modify BCAST_FUNC_OP
Boyuan Yao 2 years ago committed by GitHub
parent 8ea50d999e
commit b904748210

@@ -24,26 +24,25 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """
-    input_op_data, other_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
+    input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
     output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))
 
     # construct forward args for flop mapping
-    fwd_in_args = [input_op_data.data, other_op_data.data]
+    fwd_in_args = [opdata.data for opdata in input_op_data]
     fwd_out_args = [output_op_data.data]
 
     # calculate cost
 
     # calculate compute cost
     # NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case
-    fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args)
+    fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args)
     bwd_compute_cost = fwd_compute_cost * 2
     compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
 
     # calculate memory cost
-    param_mem_cost = activation_size(
-        [arg.data for arg in [input_op_data, other_op_data] if arg.type == OperationDataType.PARAM])
+    param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM])
     fwd_mem_cost = MemoryCost(
-        activation=activation_size([input_op_data.data, output_op_data.data]),
+        activation=activation_size(output_op_data.data),
         parameter=param_mem_cost,
     )
     bwd_mem_cost = MemoryCost(
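The hunk above generalizes binary_elementwise_meta_info from exactly two operands to any number of non-output operands, and keys the flop lookup on torch.ops.aten.add.Tensor instead of the adaptive-average-pool op. Below is a minimal, self-contained sketch of the same filter-and-aggregate pattern; OpData, OpType and tensor_bytes are hypothetical stand-ins for ColossalAI's OperationData, OperationDataType and activation_size, and the cost model (one flop per output element, backward assumed twice the forward) simply mirrors the NOTE in the diff:

from dataclasses import dataclass
from enum import Enum, auto
from typing import List

import torch


class OpType(Enum):
    ARG = auto()
    PARAM = auto()
    OUTPUT = auto()


@dataclass
class OpData:
    # hypothetical stand-in for ColossalAI's OperationData
    data: torch.Tensor
    type: OpType


def tensor_bytes(tensors: List[torch.Tensor]) -> int:
    # rough stand-in for activation_size(): total bytes of the given tensors
    return sum(t.numel() * t.element_size() for t in tensors)


def binary_elementwise_costs(*args: OpData):
    # keep every non-output operand, however many broadcast inputs there are
    input_op_data = [arg for arg in args if arg.type != OpType.OUTPUT]
    output_op_data = next(arg for arg in args if arg.type == OpType.OUTPUT)

    # elementwise add: one flop per output element; backward assumed 2x forward
    fwd_flops = output_op_data.data.numel()
    bwd_flops = 2 * fwd_flops

    # parameters contribute to parameter memory, the output to activation memory
    param_mem = tensor_bytes([arg.data for arg in input_op_data if arg.type == OpType.PARAM])
    act_mem = tensor_bytes([output_op_data.data])
    return (fwd_flops, bwd_flops), (act_mem, param_mem)


if __name__ == "__main__":
    x = OpData(torch.rand(4, 8), OpType.ARG)
    bias = OpData(torch.rand(8), OpType.PARAM)    # broadcast operand
    out = OpData(torch.rand(4, 8), OpType.OUTPUT)
    print(binary_elementwise_costs(x, bias, out))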

@@ -4,7 +4,7 @@ from typing import Dict, List, Tuple, Union
 import torch
 from torch.fx.node import Node
 
-from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
+from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -234,15 +234,19 @@ class MetaInfoNodeHandler(NodeHandler):
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
-        metainfo_vector = []
-        for strategy in self.strategies_vector:
-            metainfo = MetaInfo(strategy, target)
-            strategy.compute_cost = metainfo.compute_cost
-            strategy.memory_cost = metainfo.memory_cost
-            metainfo_vector.append(metainfo)
-
-        # attach metainfos to the handler
-        setattr(self, "metainfo_vector", metainfo_vector)
+        # Currently we haven't patched all the torch functions and modules, so if the target
+        # is not patched, we will use the default cost model to compute the cost.
+        # TODO: patch all torch functions and modules to make it clean
+        if meta_register.has(target.__class__) or meta_register.has(target):
+            metainfo_vector = []
+            for strategy in self.strategies_vector:
+                metainfo = MetaInfo(strategy, target)
+                strategy.compute_cost = metainfo.compute_cost
+                strategy.memory_cost = metainfo.memory_cost
+                metainfo_vector.append(metainfo)
+
+            # attach metainfos to the handler
+            setattr(self, "metainfo_vector", metainfo_vector)
 
         return self.strategies_vector
@@ -281,14 +285,18 @@ class MetaInfoModuleHandler(ModuleHandler):
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
-        metainfo_vector = []
-        for strategy in self.strategies_vector:
-            metainfo = MetaInfo(strategy, target)
-            strategy.compute_cost = metainfo.compute_cost
-            strategy.memory_cost = metainfo.memory_cost
-            metainfo_vector.append(metainfo)
-
-        # attach metainfos to the handler
-        setattr(self, "metainfo_vector", metainfo_vector)
+        # Currently we haven't patched all the torch functions and modules, so if the target
+        # is not patched, we will use the default cost model to compute the cost.
+        # TODO: patch all torch functions and modules to make it clean
+        if meta_register.has(target.__class__) or meta_register.has(target):
+            metainfo_vector = []
+            for strategy in self.strategies_vector:
+                metainfo = MetaInfo(strategy, target)
+                strategy.compute_cost = metainfo.compute_cost
+                strategy.memory_cost = metainfo.memory_cost
+                metainfo_vector.append(metainfo)
+
+            # attach metainfos to the handler
+            setattr(self, "metainfo_vector", metainfo_vector)
 
         return self.strategies_vector
