mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] bypass MetaInfo when unavailable and modify BCAST_FUNC_OP metainfo (#2293)
* [autoparallel] align the data_ptr with the old version of auto activation checkpoint pipeline * [autoparallel] using fwd_time and bwd_time instead of fwd_flop and bwd_flop * [autoparallel] specifycomm nodes' memory cost in construct chain * [autoparallel] fix wrong runtime apply calculation * [autoparallel] fix wrong runtime apply calculation * [autoparallel] fix wrong runtime apply calculation * [autoparallel] bypass metainfo when available and modify BCAST_FUNC_OPpull/2258/head
parent
8ea50d999e
commit
b904748210
|
@ -24,26 +24,25 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
|
||||||
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
|
Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
|
||||||
"""
|
"""
|
||||||
|
|
||||||
input_op_data, other_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
|
input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
|
||||||
output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))
|
output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))
|
||||||
|
|
||||||
# construct forward args for flop mapping
|
# construct forward args for flop mapping
|
||||||
fwd_in_args = [input_op_data.data, other_op_data.data]
|
fwd_in_args = [opdata.data for opdata in input_op_data]
|
||||||
fwd_out_args = [output_op_data.data]
|
fwd_out_args = [output_op_data.data]
|
||||||
|
|
||||||
# calculate cost
|
# calculate cost
|
||||||
|
|
||||||
# calculate compute cost
|
# calculate compute cost
|
||||||
# NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case
|
# NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case
|
||||||
fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args)
|
fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args)
|
||||||
bwd_compute_cost = fwd_compute_cost * 2
|
bwd_compute_cost = fwd_compute_cost * 2
|
||||||
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
|
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
|
||||||
|
|
||||||
# calculate memory cost
|
# calculate memory cost
|
||||||
param_mem_cost = activation_size(
|
param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM])
|
||||||
[arg.data for arg in [input_op_data, other_op_data] if arg.type == OperationDataType.PARAM])
|
|
||||||
fwd_mem_cost = MemoryCost(
|
fwd_mem_cost = MemoryCost(
|
||||||
activation=activation_size([input_op_data.data, output_op_data.data]),
|
activation=activation_size(output_op_data.data),
|
||||||
parameter=param_mem_cost,
|
parameter=param_mem_cost,
|
||||||
)
|
)
|
||||||
bwd_mem_cost = MemoryCost(
|
bwd_mem_cost = MemoryCost(
|
||||||
|
|
|
@ -4,7 +4,7 @@ from typing import Dict, List, Tuple, Union
|
||||||
import torch
|
import torch
|
||||||
from torch.fx.node import Node
|
from torch.fx.node import Node
|
||||||
|
|
||||||
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
|
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
|
||||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||||
OperationData,
|
OperationData,
|
||||||
OperationDataType,
|
OperationDataType,
|
||||||
|
@ -234,15 +234,19 @@ class MetaInfoNodeHandler(NodeHandler):
|
||||||
"""
|
"""
|
||||||
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
|
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
|
||||||
target = self.get_target_function()
|
target = self.get_target_function()
|
||||||
metainfo_vector = []
|
# Currently we haven't patched all the torch functions and modules, so if the target
|
||||||
for strategy in self.strategies_vector:
|
# is not patched, we will use the default cost model to compute the cost.
|
||||||
metainfo = MetaInfo(strategy, target)
|
# TODO: patch all torch functions and modules to make it clean
|
||||||
strategy.compute_cost = metainfo.compute_cost
|
if meta_register.has(target.__class__) or meta_register.has(target):
|
||||||
strategy.memory_cost = metainfo.memory_cost
|
metainfo_vector = []
|
||||||
metainfo_vector.append(metainfo)
|
for strategy in self.strategies_vector:
|
||||||
|
metainfo = MetaInfo(strategy, target)
|
||||||
|
strategy.compute_cost = metainfo.compute_cost
|
||||||
|
strategy.memory_cost = metainfo.memory_cost
|
||||||
|
metainfo_vector.append(metainfo)
|
||||||
|
|
||||||
# attach metainfos to the handler
|
# attach metainfos to the handler
|
||||||
setattr(self, "metainfo_vector", metainfo_vector)
|
setattr(self, "metainfo_vector", metainfo_vector)
|
||||||
|
|
||||||
return self.strategies_vector
|
return self.strategies_vector
|
||||||
|
|
||||||
|
@ -281,14 +285,18 @@ class MetaInfoModuleHandler(ModuleHandler):
|
||||||
"""
|
"""
|
||||||
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
|
super().register_strategy(compute_resharding_cost=compute_resharding_cost)
|
||||||
target = self.get_target_function()
|
target = self.get_target_function()
|
||||||
metainfo_vector = []
|
# Currently we haven't patched all the torch functions and modules, so if the target
|
||||||
for strategy in self.strategies_vector:
|
# is not patched, we will use the default cost model to compute the cost.
|
||||||
metainfo = MetaInfo(strategy, target)
|
# TODO: patch all torch functions and modules to make it clean
|
||||||
strategy.compute_cost = metainfo.compute_cost
|
if meta_register.has(target.__class__) or meta_register.has(target):
|
||||||
strategy.memory_cost = metainfo.memory_cost
|
metainfo_vector = []
|
||||||
metainfo_vector.append(metainfo)
|
for strategy in self.strategies_vector:
|
||||||
|
metainfo = MetaInfo(strategy, target)
|
||||||
|
strategy.compute_cost = metainfo.compute_cost
|
||||||
|
strategy.memory_cost = metainfo.memory_cost
|
||||||
|
metainfo_vector.append(metainfo)
|
||||||
|
|
||||||
# attach metainfos to the handler
|
# attach metainfos to the handler
|
||||||
setattr(self, "metainfo_vector", metainfo_vector)
|
setattr(self, "metainfo_vector", metainfo_vector)
|
||||||
|
|
||||||
return self.strategies_vector
|
return self.strategies_vector
|
||||||
|
|
Loading…
Reference in New Issue