mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] resnet block runtime apply (#1709)
* [autoparallel] resnet block runtime apply * seperate buffer and parameter in MemoryCost * polish code * add comments and todos * fix test issuepull/1713/head
parent
b0a23dc4fc
commit
845ff4a47a
|
@ -36,7 +36,30 @@ class BatchNormModuleHandler(ModuleHandler):
|
|||
logical_shape=self.named_parameters['weight'].shape)
|
||||
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
|
||||
|
||||
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
|
||||
physical_running_mean_operand = OperationData(name="running_mean",
|
||||
type=OperationDataType.BUFFER,
|
||||
data=self.named_buffers['running_mean'],
|
||||
logical_shape=self.named_buffers['running_mean'].shape)
|
||||
|
||||
physical_running_var_operand = OperationData(name="running_var",
|
||||
type=OperationDataType.BUFFER,
|
||||
data=self.named_buffers['running_var'],
|
||||
logical_shape=self.named_buffers['running_var'].shape)
|
||||
|
||||
physical_num_batches_tracked_operand = OperationData(
|
||||
name="num_batches_tracked",
|
||||
type=OperationDataType.BUFFER,
|
||||
data=self.named_buffers['num_batches_tracked'],
|
||||
logical_shape=self.named_buffers['num_batches_tracked'].shape)
|
||||
|
||||
mapping = {
|
||||
"input": physical_input_operand,
|
||||
"other": physical_other_operand,
|
||||
"output": physical_output,
|
||||
"running_mean": physical_running_mean_operand,
|
||||
"running_var": physical_running_var_operand,
|
||||
"num_batches_tracked": physical_num_batches_tracked_operand
|
||||
}
|
||||
|
||||
if self.named_parameters['bias'] is not None:
|
||||
physical_bias_operand = OperationData(name="bias",
|
||||
|
|
|
@ -146,7 +146,10 @@ class ModuleHandler(NodeHandler):
|
|||
f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.'
|
||||
module = self.node.graph.owning_module.get_submodule(self.node.target)
|
||||
named_parameters = list(module.named_parameters(recurse=False))
|
||||
named_buffers = list(module.named_buffers(recurse=False))
|
||||
# convert named parameters from list to dict
|
||||
named_parameters = {k: v for k, v in named_parameters}
|
||||
named_buffers = {k: v for k, v in named_buffers}
|
||||
self.module = module
|
||||
self.named_parameters = named_parameters
|
||||
self.named_buffers = named_buffers
|
||||
|
|
|
@ -13,6 +13,7 @@ __all__ = ['ReshapeHandler']
|
|||
@operator_registry.register(torch.reshape)
|
||||
@operator_registry.register(torch.flatten)
|
||||
@operator_registry.register(torch.Tensor.permute)
|
||||
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
|
||||
class ReshapeHandler(NodeHandler):
|
||||
"""
|
||||
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
|
||||
|
|
|
@ -64,7 +64,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
|||
forward_size_mapping = {
|
||||
'input': self._compute_size_in_bytes(strategy, "input"),
|
||||
'other': self._compute_size_in_bytes(strategy, "other"),
|
||||
'output': self._compute_size_in_bytes(strategy, "output")
|
||||
'output': self._compute_size_in_bytes(strategy, "output"),
|
||||
'running_mean': self._compute_size_in_bytes(strategy, "running_mean"),
|
||||
'running_var': self._compute_size_in_bytes(strategy, "running_var"),
|
||||
}
|
||||
|
||||
if self.has_bias:
|
||||
|
@ -75,24 +77,27 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
|||
backward_size_mapping.pop("output")
|
||||
# compute fwd cost incurred
|
||||
# fwd_cost = input + other + bias + output
|
||||
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
|
||||
fwd_activation_cost = sum(
|
||||
[v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)])
|
||||
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
|
||||
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
|
||||
fwd_buffer_cost = sum([v for k, v in forward_size_mapping.items() if self.is_buffer(k)])
|
||||
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost, buffer=fwd_buffer_cost)
|
||||
|
||||
# compute bwd cost incurred
|
||||
# bwd_cost = input_grad + other_grad + bias_grad
|
||||
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
|
||||
bwd_activation_cost = sum(
|
||||
[v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)])
|
||||
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
|
||||
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
|
||||
|
||||
# compute total cost
|
||||
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
|
||||
parameter=fwd_parameter_cost + bwd_parameter_cost)
|
||||
parameter=fwd_parameter_cost + bwd_parameter_cost,
|
||||
buffer=fwd_buffer_cost)
|
||||
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
||||
strategy.memory_cost = memory_cost
|
||||
|
||||
def split_input_channel(self, mesh_dim_0):
|
||||
strategy_list = []
|
||||
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
|
||||
dim_partition_dict_mapping = {
|
||||
"input": {
|
||||
|
@ -104,6 +109,13 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
|||
"output": {
|
||||
1: [mesh_dim_0]
|
||||
},
|
||||
"running_mean": {
|
||||
0: [mesh_dim_0]
|
||||
},
|
||||
"running_var": {
|
||||
0: [mesh_dim_0]
|
||||
},
|
||||
"num_batches_tracked": {},
|
||||
}
|
||||
if self.has_bias:
|
||||
dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0]}
|
||||
|
@ -128,6 +140,13 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
|||
"output": {
|
||||
1: [mesh_dim_0, mesh_dim_1]
|
||||
},
|
||||
"running_mean": {
|
||||
0: [mesh_dim_0, mesh_dim_1]
|
||||
},
|
||||
"running_var": {
|
||||
0: [mesh_dim_0, mesh_dim_1]
|
||||
},
|
||||
"num_batches_tracked": {},
|
||||
}
|
||||
if self.has_bias:
|
||||
dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0, mesh_dim_1]}
|
||||
|
@ -146,6 +165,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
|||
"input": {},
|
||||
"other": {},
|
||||
"output": {},
|
||||
"running_mean": {},
|
||||
"running_var": {},
|
||||
"num_batches_tracked": {},
|
||||
}
|
||||
if self.has_bias:
|
||||
dim_partition_dict_mapping["bias"] = {}
|
||||
|
@ -168,6 +190,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
|||
"output": {
|
||||
0: [mesh_dim_0]
|
||||
},
|
||||
"running_mean": {},
|
||||
"running_var": {},
|
||||
"num_batches_tracked": {},
|
||||
}
|
||||
if self.has_bias:
|
||||
dim_partition_dict_mapping["bias"] = {}
|
||||
|
@ -199,6 +224,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
|||
"output": {
|
||||
0: [mesh_dim_0, mesh_dim_1]
|
||||
},
|
||||
"running_mean": {},
|
||||
"running_var": {},
|
||||
"num_batches_tracked": {},
|
||||
}
|
||||
if self.has_bias:
|
||||
dim_partition_dict_mapping["bias"] = {}
|
||||
|
@ -234,6 +262,13 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
|||
0: [mesh_dim_0],
|
||||
1: [mesh_dim_1],
|
||||
},
|
||||
"running_mean": {
|
||||
0: [mesh_dim_1],
|
||||
},
|
||||
"running_var": {
|
||||
0: [mesh_dim_1],
|
||||
},
|
||||
"num_batches_tracked": {},
|
||||
}
|
||||
if self.has_bias:
|
||||
dim_partition_dict_mapping["bias"] = {
|
||||
|
@ -273,16 +308,22 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
|||
# RS01 = RS01 x S01
|
||||
strategy_list.append(self.split_input_channel_1d(0, 1))
|
||||
|
||||
# The strategies with SYNC_BN are temporarily commented,
|
||||
# because it requires some additional passes to keep runtime
|
||||
# computation correctness.
|
||||
|
||||
# TODO: The strategies below should be uncommented after runtime
|
||||
# passes ready.
|
||||
# SR = SR x R WITH SYNC_BN
|
||||
strategy_list.append(self.split_input_batch(0))
|
||||
strategy_list.append(self.split_input_batch(1))
|
||||
# strategy_list.append(self.split_input_batch(0))
|
||||
# strategy_list.append(self.split_input_batch(1))
|
||||
|
||||
# SS = SS x S WITH SYNC_BN
|
||||
strategy_list.append(self.split_input_both_dim(0, 1))
|
||||
strategy_list.append(self.split_input_both_dim(1, 0))
|
||||
# strategy_list.append(self.split_input_both_dim(0, 1))
|
||||
# strategy_list.append(self.split_input_both_dim(1, 0))
|
||||
|
||||
# S01R = S01R x R WITH SYNC_BN
|
||||
strategy_list.append(self.split_input_batch_1d(0, 1))
|
||||
# strategy_list.append(self.split_input_batch_1d(0, 1))
|
||||
|
||||
for strategy in strategy_list:
|
||||
self.update_communication_cost(strategy)
|
||||
|
|
|
@ -35,6 +35,10 @@ class StrategyGenerator(ABC):
|
|||
other_data = self.op_data[op_data_name]
|
||||
return other_data.type == OperationDataType.PARAM
|
||||
|
||||
def is_buffer(self, op_data_name):
|
||||
other_data = self.op_data[op_data_name]
|
||||
return other_data.type == OperationDataType.BUFFER
|
||||
|
||||
def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
|
||||
communication_action_mapping: Dict[str, CommSpec]):
|
||||
"""
|
||||
|
|
|
@ -20,7 +20,8 @@ class OperationDataType(Enum):
|
|||
INPUT = 0
|
||||
ARG = 1
|
||||
PARAM = 2
|
||||
OUTPUT = 3
|
||||
BUFFER = 3
|
||||
OUTPUT = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -80,6 +81,7 @@ class MemoryCost:
|
|||
"""
|
||||
activation: int = 0
|
||||
parameter: int = 0
|
||||
buffer: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
|
||||
import torch
|
||||
|
||||
|
||||
class CostGraph:
|
||||
|
@ -51,7 +52,6 @@ class CostGraph:
|
|||
if src_node not in self.nodes:
|
||||
continue
|
||||
node_pair = (src_node, dst_node)
|
||||
# src_index = strategies_vector.predecessor_nodes.index(src_node)
|
||||
edge_cost = {}
|
||||
for i in range(len(strategies_vector)):
|
||||
for j in range(len(src_node.strategies_vector)):
|
||||
|
@ -62,10 +62,12 @@ class CostGraph:
|
|||
edge_cost[(j, i)] = resharding_cost_item.total
|
||||
self.edge_costs[node_pair] = edge_cost
|
||||
# add parents and children attribute to node
|
||||
setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
|
||||
setattr(dst_node, 'children', strategies_vector.successor_nodes)
|
||||
self._remove_invalid_node(dst_node, 'parents')
|
||||
self._remove_invalid_node(dst_node, 'children')
|
||||
parent_nodes = [node for node in strategies_vector.predecessor_nodes]
|
||||
children_nodes = [node for node in strategies_vector.successor_nodes]
|
||||
setattr(dst_node, 'parents', parent_nodes)
|
||||
setattr(dst_node, 'children', children_nodes)
|
||||
# self._remove_invalid_node(dst_node, 'parents')
|
||||
# self._remove_invalid_node(dst_node, 'children')
|
||||
|
||||
if self.simplify and strategies_vector.check_merge():
|
||||
for followed_node in strategies_vector.predecessor_nodes:
|
||||
|
|
|
@ -169,10 +169,7 @@ class Solver:
|
|||
else:
|
||||
communication_costs.append(origin_communication_cost)
|
||||
memory_costs.append(memory_cost)
|
||||
# if isinstance(memory_cost, tuple):
|
||||
# memory_costs.append(memory_cost[0])
|
||||
# else:
|
||||
# memory_costs.append(memory_cost)
|
||||
|
||||
compute_costs = np.array(compute_costs)
|
||||
communication_costs = np.array(communication_costs)
|
||||
memory_costs = np.array(memory_costs)
|
||||
|
|
|
@ -36,16 +36,19 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
|
|||
for name, param in target_module.named_parameters():
|
||||
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
|
||||
setattr(param, 'sharding_spec', origin_sharding_spec)
|
||||
target_weight_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
|
||||
apply(param, target_weight_sharding_spec)
|
||||
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
|
||||
apply(param, target_sharding_spec)
|
||||
|
||||
for name, buffer in target_module.named_buffers():
|
||||
origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
|
||||
setattr(buffer, 'sharding_spec', origin_sharding_spec)
|
||||
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
|
||||
apply(buffer, target_sharding_spec)
|
||||
|
||||
# the dict to get input sharding specs of user node
|
||||
sharding_spec_convert_dict = {}
|
||||
for index, node in enumerate(nodes):
|
||||
target_sharding_specs = []
|
||||
if node.name == 'bn1':
|
||||
print(node.strategies_vector.successor_nodes)
|
||||
assert False
|
||||
for user_node in node.strategies_vector.successor_nodes:
|
||||
# node_index = user_node.strategies_vector.predecessor_nodes.index(node)
|
||||
# target_sharding_spec = user_node.best_strategy.input_shardings[node_index]
|
||||
|
|
|
@ -0,0 +1,172 @@
|
|||
from functools import partial
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from torch.fx import GraphModule
|
||||
import torch.nn as nn
|
||||
import pytest
|
||||
from colossalai import device
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
|
||||
from colossalai.fx.tracer.tracer import ColoTracer
|
||||
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import shape_consistency_pass, solution_annotatation_pass
|
||||
from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
|
||||
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
|
||||
from copy import deepcopy
|
||||
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
|
||||
from torchvision.models import resnet34, resnet50
|
||||
from colossalai.auto_parallel.tensor_shard.constants import *
|
||||
from colossalai.testing import assert_close_loose, assert_close
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
|
||||
seed = 128
|
||||
cudnn_benchmark = False
|
||||
cudnn_deterministic = True
|
||||
|
||||
|
||||
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=dilation,
|
||||
groups=groups,
|
||||
bias=False,
|
||||
dilation=dilation,
|
||||
)
|
||||
|
||||
|
||||
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
||||
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
||||
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
||||
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
||||
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
||||
|
||||
expansion: int = 4
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inplanes: int,
|
||||
planes: int,
|
||||
stride: int = 1,
|
||||
downsample=None,
|
||||
groups: int = 1,
|
||||
base_width: int = 64,
|
||||
dilation: int = 1,
|
||||
norm_layer=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
width = int(planes * (base_width / 64.0)) * groups
|
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv1x1(inplanes, width)
|
||||
self.bn1 = norm_layer(width)
|
||||
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
||||
self.bn2 = norm_layer(width)
|
||||
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||
self.bn3 = norm_layer(planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def check_apply_bottleneck(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
input = torch.rand(256, 64, 64, 64).cuda()
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
# [[0, 1]
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=False)
|
||||
entire_shape = torch.Size((4, 4, 8, 8))
|
||||
|
||||
tracer = ColoTracer()
|
||||
model = Bottleneck(64, 64, 1, norm_layer=torch.nn.modules.batchnorm.BatchNorm2d).cuda()
|
||||
# graph():
|
||||
# %x : torch.Tensor [#users=1] = placeholder[target=x]
|
||||
# %conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
|
||||
# %bn1 : [#users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
|
||||
# %relu : [#users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
|
||||
# %conv2 : [#users=1] = call_module[target=conv2](args = (%relu,), kwargs = {})
|
||||
# %bn2 : [#users=1] = call_module[target=bn2](args = (%conv2,), kwargs = {})
|
||||
# %relu_1 : [#users=1] = call_module[target=relu](args = (%bn2,), kwargs = {})
|
||||
# %conv3 : [#users=1] = call_module[target=conv3](args = (%relu_1,), kwargs = {})
|
||||
# %bn3 : [#users=1] = call_module[target=bn3](args = (%conv3,), kwargs = {})
|
||||
# %relu_2 : [#users=1] = call_module[target=relu](args = (%bn3,), kwargs = {})
|
||||
# return relu_2
|
||||
input_sample = {'x': torch.rand(256, 64, 224, 224).to('meta')}
|
||||
cuda_rng_state = torch.cuda.get_rng_state()
|
||||
origin_output = model(input)
|
||||
graph = tracer.trace(root=model, meta_args=input_sample)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
gm.recompile()
|
||||
solver_options = SolverOptions(fast=True)
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
|
||||
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
|
||||
cost_graph.simplify_graph()
|
||||
graph_analyser = GraphAnalyser(gm)
|
||||
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
|
||||
ret = solver.call_solver_serialized_args()
|
||||
solution = list(ret[0])
|
||||
print(solution)
|
||||
device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()
|
||||
sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
|
||||
shape_consistency_pass(gm)
|
||||
gm.recompile()
|
||||
nodes = [node for node in gm.graph.nodes]
|
||||
# TODO: wrap the gm to avoid the influence of the user training code
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
output = gm(input, sharding_spec_dict, origin_spec_dict)
|
||||
assert output.shape == origin_output.shape
|
||||
assert output.equal(origin_output)
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_apply():
|
||||
world_size = 4
|
||||
run_func = partial(check_apply_bottleneck, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_apply()
|
|
@ -7,8 +7,10 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationDa
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.fx.tracer.meta_patch.patched_module import linear
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.skip("skip due to passes not ready")
|
||||
def test_bn_module_handler():
|
||||
model = nn.Sequential(nn.BatchNorm2d(16).to('meta'))
|
||||
tracer = ColoTracer()
|
||||
|
|
Loading…
Reference in New Issue