
[autoparallel] bypass MetaInfo when unavailable and modify BCAST_FUNC_OP metainfo (#2293)

* [autoparallel] align the data_ptr with the old version of auto activation checkpoint pipeline

* [autoparallel] using fwd_time and bwd_time instead of fwd_flop and bwd_flop

* [autoparallel] specify comm nodes' memory cost in construct chain

* [autoparallel] fix wrong runtime apply calculation

* [autoparallel] fix wrong runtime apply calculation

* [autoparallel] fix wrong runtime apply calculation

* [autoparallel] bypass metainfo when unavailable and modify BCAST_FUNC_OP
Boyuan Yao committed 2 years ago via GitHub
commit b904748210
Changed files:
  colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py (11 changed lines)
  colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py (46 changed lines)

colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py (11 changed lines)

@@ -24,26 +24,25 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
         Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
     """
-    input_op_data, other_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
+    input_op_data = [arg for arg in args if arg.type != OperationDataType.OUTPUT]
     output_op_data = next(filter(lambda arg: arg.type == OperationDataType.OUTPUT, args))

     # construct forward args for flop mapping
-    fwd_in_args = [input_op_data.data, other_op_data.data]
+    fwd_in_args = [opdata.data for opdata in input_op_data]
     fwd_out_args = [output_op_data.data]

-    # calculate cost
+    # calculate compute cost
     # NOTE: we set bwd_compute_cost two times of fwd_compute_cost in this case
-    fwd_compute_cost = flop_mapping[torch.ops.aten._adaptive_avg_pool2d.default](fwd_in_args, fwd_out_args)
+    fwd_compute_cost = flop_mapping[torch.ops.aten.add.Tensor](fwd_in_args, fwd_out_args)
     bwd_compute_cost = fwd_compute_cost * 2
     compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)

     # calculate memory cost
-    param_mem_cost = activation_size(
-        [arg.data for arg in [input_op_data, other_op_data] if arg.type == OperationDataType.PARAM])
+    param_mem_cost = activation_size([arg.data for arg in input_op_data if arg.type == OperationDataType.PARAM])
     fwd_mem_cost = MemoryCost(
-        activation=activation_size([input_op_data.data, output_op_data.data]),
+        activation=activation_size(output_op_data.data),
         parameter=param_mem_cost,
     )
     bwd_mem_cost = MemoryCost(
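
The change above makes the BCAST_FUNC_OP metainfo accept any number of non-output operands, switches the FLOP lookup from torch.ops.aten._adaptive_avg_pool2d.default to torch.ops.aten.add.Tensor, and counts only the output tensor as the forward activation. A minimal standalone sketch of that accounting (bcast_elementwise_cost is a hypothetical helper for illustration, not the ColossalAI implementation):

import torch


def bcast_elementwise_cost(output: torch.Tensor) -> dict:
    """Rough cost sketch for a broadcast binary elementwise op such as add.

    Assumptions (illustration only, not the ColossalAI code):
      * forward FLOPs ~= number of elements in the broadcast output
      * backward FLOPs ~= 2x forward, mirroring the NOTE in the patch
      * forward activation memory counts only the output tensor, matching the
        change from activation_size([input, output]) to activation_size(output)
    """
    fwd_flop = output.numel()
    bwd_flop = 2 * fwd_flop
    activation_bytes = output.numel() * output.element_size()
    return {
        "compute": {"fwd": fwd_flop, "bwd": bwd_flop, "total": fwd_flop + bwd_flop},
        "memory": {"fwd_activation_bytes": activation_bytes},
    }


# usage: broadcast add of a (4, 1) tensor and a (4, 8) tensor
a, b = torch.randn(4, 1), torch.randn(4, 8)
print(bcast_elementwise_cost(a + b))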

colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py (46 changed lines)

@@ -4,7 +4,7 @@ from typing import Dict, List, Tuple, Union
 import torch
 from torch.fx.node import Node

-from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo
+from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -234,15 +234,19 @@ class MetaInfoNodeHandler(NodeHandler):
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
-        metainfo_vector = []
-        for strategy in self.strategies_vector:
-            metainfo = MetaInfo(strategy, target)
-            strategy.compute_cost = metainfo.compute_cost
-            strategy.memory_cost = metainfo.memory_cost
-            metainfo_vector.append(metainfo)
-
-        # attach metainfos to the handler
-        setattr(self, "metainfo_vector", metainfo_vector)
+        # Currently we haven't patched all the torch functions and modules, so if the target
+        # is not patched, we will use the default cost model to compute the cost.
+        # TODO: patch all torch functions and modules to make it clean
+        if meta_register.has(target.__class__) or meta_register.has(target):
+            metainfo_vector = []
+            for strategy in self.strategies_vector:
+                metainfo = MetaInfo(strategy, target)
+                strategy.compute_cost = metainfo.compute_cost
+                strategy.memory_cost = metainfo.memory_cost
+                metainfo_vector.append(metainfo)
+
+            # attach metainfos to the handler
+            setattr(self, "metainfo_vector", metainfo_vector)

         return self.strategies_vector
@@ -281,14 +285,18 @@ class MetaInfoModuleHandler(ModuleHandler):
         """
         super().register_strategy(compute_resharding_cost=compute_resharding_cost)
         target = self.get_target_function()
-        metainfo_vector = []
-        for strategy in self.strategies_vector:
-            metainfo = MetaInfo(strategy, target)
-            strategy.compute_cost = metainfo.compute_cost
-            strategy.memory_cost = metainfo.memory_cost
-            metainfo_vector.append(metainfo)
-
-        # attach metainfos to the handler
-        setattr(self, "metainfo_vector", metainfo_vector)
+        # Currently we haven't patched all the torch functions and modules, so if the target
+        # is not patched, we will use the default cost model to compute the cost.
+        # TODO: patch all torch functions and modules to make it clean
+        if meta_register.has(target.__class__) or meta_register.has(target):
+            metainfo_vector = []
+            for strategy in self.strategies_vector:
+                metainfo = MetaInfo(strategy, target)
+                strategy.compute_cost = metainfo.compute_cost
+                strategy.memory_cost = metainfo.memory_cost
+                metainfo_vector.append(metainfo)
+
+            # attach metainfos to the handler
+            setattr(self, "metainfo_vector", metainfo_vector)

         return self.strategies_vector
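
Both handlers now guard the MetaInfo path behind a registry lookup and fall back to the default cost model whenever the target has not been patched. A minimal sketch of that bypass pattern with a toy registry (SimpleMetaRegistry and annotate_strategies are hypothetical stand-ins, not the real meta_register API):

from typing import Callable, Dict


class SimpleMetaRegistry:
    """Toy stand-in for a meta-function registry (illustration only)."""

    def __init__(self) -> None:
        self._table: Dict[object, Callable] = {}

    def register(self, target):
        def wrapper(fn):
            self._table[target] = fn
            return fn
        return wrapper

    def has(self, target) -> bool:
        return target in self._table

    def get(self, target) -> Callable:
        return self._table[target]


registry = SimpleMetaRegistry()


@registry.register("add")
def add_meta():
    # pretend these costs came from a registered meta function
    return {"compute_cost": 1, "memory_cost": 1}


def annotate_strategies(strategies, target):
    # Mirror the patch's check of both the target's class (module case) and the
    # target itself (function case); bypass the registry when neither is patched
    # and keep whatever default costs the base handler already attached.
    key = target if registry.has(target) else type(target)
    if not registry.has(key):
        return strategies
    for strategy in strategies:
        strategy.update(registry.get(key)())
    return strategies


print(annotate_strategies([{"name": "S0"}], "add"))     # costs filled from the registry
print(annotate_strategies([{"name": "S0"}], "matmul"))  # target not patched -> bypassed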
