[autoparallel] adapt autoparallel tests with latest api (#2626)

pull/2664/head
YuliangLiu0306 2023-02-08 15:02:12 +08:00 committed by GitHub
parent c375563653
commit cb3d1bef62
8 changed files with 59 additions and 583 deletions
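
Summary: the touched tests drop the hand-rolled trace / strategy-construction / ILP-solve / runtime-pass pipeline and call the new initialize API instead. A minimal sketch of the migrated flow, pieced together from the call sites in the diffs below (signatures are as the tests use them, not a verified public contract):

import torch

from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh


def shard_with_new_api(model: torch.nn.Module) -> torch.nn.Module:
    # hypothetical helper, for illustration only
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)    # four GPUs arranged as [[0, 1], [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    # meta tensors let the solver see shapes without allocating device memory
    meta_args = {'x': torch.rand(4, 4).to('meta')}
    return initialize_model(model, meta_args=meta_args, device_mesh=device_mesh)

The returned module carries its sharding specs and communication actions internally, so callers now write output = gm(input) instead of threading sharding_spec_dict, origin_spec_dict and comm_actions_dict through every forward call.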

View File

@@ -247,12 +247,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         strategies.append(self.split_rhs_space_both_contract(1, 0))
         # RR= RS x SR
-        # strategies.append(self.recompute_split_both_contract(0))
-        # strategies.append(self.recompute_split_both_contract(1))
+        strategies.append(self.recompute_split_both_contract(0))
+        strategies.append(self.recompute_split_both_contract(1))
-        # # RS = RR x RS
-        # strategies.append(self.split_rhs_space_only(0))
-        # strategies.append(self.split_rhs_space_only(1))
+        # RS = RR x RS
+        strategies.append(self.split_rhs_space_only(0))
+        strategies.append(self.split_rhs_space_only(1))
         # S01R = S01R x RR
         strategies.append(self.split_lhs_1st_dim_1d(0, 1))
@@ -263,8 +263,8 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # RS01 = RR x RS01
         strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
-        # # RR = RR x RR
-        # strategies.append(self.non_split())
+        # RR = RR x RR
+        strategies.append(self.non_split())
         return strategies
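
A quick gloss of the spec notation in the hunk above; this is my reading of the Alpa-style convention the Colossal-AI sharding code uses, not text from the commit:

# R     - a tensor dimension replicated on every device of the mesh
# S0/S1 - a dimension sharded over mesh axis 0 or 1; S01 shards over both
# "RR = RS x SR" then denotes a matmul whose lhs and rhs are both sharded
# along the contracting dimension: every device holds a partial sum, and
# materializing the replicated RR output implies an all-reduce, which is
# why these re-enabled strategies go through recompute_split_both_contract.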

View File

@@ -62,9 +62,6 @@ class CostGraph:
                         else:
                             edge_cost[(j, i)] = resharding_cost_item.total
                 self.edge_costs[node_pair] = edge_cost
-            # add parents and children attribute to node
-            # parent_nodes = [node for node in strategies_vector.predecessor_nodes]
-            # children_nodes = [node for node in strategies_vector.successor_nodes]
             parent_nodes = []
             children_nodes = []

View File

@@ -4,21 +4,11 @@ import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
-from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, assert_close_loose, rerun_if_address_is_in_use
+from colossalai.testing import assert_close, rerun_if_address_is_in_use
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from colossalai.utils import free_port
@@ -63,42 +53,9 @@ def check_linear_module(rank, world_size, port):
     # [[0, 1]
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    tracer = ColoTracer()
-    # graph():
-    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
-    #     %linear_weight : [#users=1] = get_attr[target=linear.weight]
-    #     %linear_bias : [#users=1] = get_attr[target=linear.bias]
-    #     %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {})
-    #     %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
-    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
-    #     return mul
-    graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 4).to('meta')})
-    # def forward(self, x : torch.Tensor):
-    #     linear_weight = self.linear.weight
-    #     linear_bias = self.linear.bias
-    #     linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
-    #     add = linear + linear_bias; linear = linear_bias = None
-    #     mul = add * 2; add = None
-    #     return mul
-    gm = ColoGraphModule(model, graph)
-    gm.recompile()
-    node_list = list(graph.nodes)
-    solver_options = SolverOptions()
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-    strategies_constructor.build_strategies_and_cost()
-    linear_node = node_list[3]
-    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
-    cost_graph.simplify_graph()
-    graph_analyser = GraphAnalyser(gm)
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
-    ret = solver.call_solver_serialized_args()
-    solution = list(ret[0])
-    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
-    gm = runtime_apply_pass(gm)
-    gm.recompile()
-    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+    meta_args = {'x': torch.rand(4, 4).to('meta')}
+    gm = initialize_model(model, meta_args=meta_args, device_mesh=device_mesh)
+    output = gm(input)
     assert_close(output, output_compare)
@@ -113,47 +70,9 @@ def check_conv_module(rank, world_size, port):
     # [[0, 1]
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    tracer = ColoTracer()
-    # graph():
-    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
-    #     %conv_weight : [#users=1] = get_attr[target=conv.weight]
-    #     %conv_bias : [#users=1] = get_attr[target=conv.bias]
-    #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
-    #     %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
-    #     %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
-    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
-    #     return mul
-    graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')})
-    # def forward(self, x : torch.Tensor):
-    #     conv_weight = self.conv.weight
-    #     conv_bias = self.conv.bias
-    #     conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None
-    #     view = conv_bias.view([1, -1, 1, 1]); conv_bias = None
-    #     add = conv2d + view; conv2d = view = None
-    #     mul = add * 2; add = None
-    #     return mul
-    gm = ColoGraphModule(model, graph)
-    gm.recompile()
-    node_list = list(graph.nodes)
-    conv_node = node_list[3]
-    solver_options = SolverOptions()
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-    strategies_constructor.build_strategies_and_cost()
-    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
-    cost_graph.simplify_graph()
-    graph_analyser = GraphAnalyser(gm)
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
-    ret = solver.call_solver_serialized_args()
-    solution = list(ret[0])
-    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
-    gm = runtime_apply_pass(gm)
-    gm.recompile()
-    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+    meta_args = {'x': torch.rand(4, 3, 64, 64).to('meta')}
+    gm = initialize_model(model, meta_args=meta_args, device_mesh=device_mesh)
+    output = gm(input)
     assert_close(output, output_compare)

View File

@@ -1,131 +0,0 @@
-import copy
-import random
-from functools import partial
-from time import time
-from typing import Dict, Optional, Tuple, Union
-import numpy as np
-import psutil
-import pytest
-import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
-import transformers
-from torch.fx import GraphModule
-from torch.profiler import ProfilerActivity, profile, record_function, schedule, tensorboard_trace_handler
-from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
-from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
-from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize, initialize_model
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.initialize import launch, launch_from_torch
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
-from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
-from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2LMHeadModel, GPTLMLoss
-BATCH_SIZE = 32
-SEQ_LENGTH = 256
-HIDDEN_DIM = 16384
-NUM_HEADS = 128
-NUM_LAYERS = 4
-VOCAB_SIZE = 50257
-NUM_STEPS = 10
-FP16 = True
-def get_cpu_mem():
-    return psutil.Process().memory_info().rss / 1024**2
-def get_gpu_mem():
-    return torch.cuda.memory_allocated() / 1024**2
-def get_mem_info(prefix=''):
-    return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
-def get_tflops(model_numel, batch_size, seq_len, step_time):
-    # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
-    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 4
-# Randomly Generated Data
-def get_data(batch_size, seq_len, vocab_size):
-    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
-    attention_mask = torch.ones_like(input_ids)
-    return input_ids, attention_mask
-def main():
-    disable_existing_loggers()
-    launch_from_torch(config={})
-    logger = get_dist_logger()
-    config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM)
-    if FP16:
-        model = GPT2LMHeadModel(config=config).half().to('cuda')
-    else:
-        model = GPT2LMHeadModel(config=config).to('cuda')
-    global_numel = sum([p.numel() for p in model.parameters()])
-    meta_input_sample = {
-        'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
-        'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
-    }
-    physical_mesh_id = torch.arange(0, 4)
-    mesh_shape = (2, 2)
-    # [[0, 1]
-    # [2, 3]]
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    gm = initialize_model(model, meta_input_sample, device_mesh)
-    # build criterion
-    criterion = GPTLMLoss()
-    optimizer = torch.optim.Adam(gm.parameters(), lr=0.01)
-    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
-    get_tflops_func = partial(get_tflops, global_numel, BATCH_SIZE, SEQ_LENGTH)
-    torch.cuda.synchronize()
-    model.train()
-    # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
-    #              schedule=schedule(wait=1, warmup=2, active=2),
-    #              on_trace_ready=tensorboard_trace_handler(f'log/dummy_data/bs128_seq128_new'),
-    #              record_shapes=True,
-    #              profile_memory=True) as prof:
-    # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]) as prof:
-    for n in range(10):
-        # we just use randomly generated data here
-        input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE)
-        optimizer.zero_grad()
-        start = time()
-        outputs = gm(input_ids, attn_mask)
-        loss = criterion(outputs, input_ids)
-        loss.backward()
-        optimizer.step()
-        # prof.step()
-        torch.cuda.synchronize()
-        step_time = time() - start
-        logger.info(
-            f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
-            ranks=[0])
-    # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10))
-    torch.cuda.synchronize()
-if __name__ == '__main__':
-    main()

View File

@@ -1,32 +1,27 @@
 import copy
 import random
 from functools import partial
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict
 import numpy as np
 import pytest
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import transformers
 from torch.fx import GraphModule
-from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
-from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
-from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
+from colossalai.auto_parallel.tensor_shard.initialize import (
+    ModuleWrapper,
+    build_strategy_constructor,
+    solve_solution,
+    transform_to_sharded_model,
 )
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
+from colossalai.tensor.shape_consistency import to_global
 from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from colossalai.utils import free_port
@@ -49,6 +44,7 @@ def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, torch.Tensor],
                        best_sharding_spec_dict: Dict[str, ShardingSpec]):
     for name, param in module.named_parameters():
         param_grad = param.grad
+        name = name.replace('module.', '')
         origin_param_grad = origin_param_dict[name].grad
         atoms = name.split('.')
         new_name = '_'.join(atoms)
@@ -115,30 +111,17 @@ def check_attention_layer(rank, model_cls, world_size, port):
     # [[0, 1]
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    shape_consistency_manager = ShapeConsistencyManager()
     tracer = ColoTracer()
     graph = tracer.trace(root=model, meta_args=meta_input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
-    graph_analyser = GraphAnalyser(gm)
-    liveness_list = graph_analyser.liveness_analysis()
-    solver_options = SolverOptions()
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-    strategies_constructor.build_strategies_and_cost()
+    strategies_constructor = build_strategy_constructor(graph, device_mesh)
+    solution = solve_solution(gm, strategies_constructor, memory_budget=-1)
+    gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
+    gm = ModuleWrapper(gm, *sharding_spec_dicts)
-    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
-    cost_graph.simplify_graph()
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
-    ret = solver.call_solver_serialized_args()
-    solution = list(ret[0])
-    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
-        gm, solution, device_mesh, strategies_constructor)
-    gm = runtime_apply_pass(gm)
-    gm.recompile()
     nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
     best_sharding_spec_dict = {}
     for index, node in enumerate(nodes):
@@ -149,7 +132,7 @@ def check_attention_layer(rank, model_cls, world_size, port):
     origin_output = test_model(*test_input_sample)
     torch.cuda.set_rng_state(cuda_rng_state)
     torch.set_rng_state(cpu_rng_state)
-    output = gm(*input_sample, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+    output = gm(*input_sample)
     assert_close(output, origin_output, rtol=1e-03, atol=1e-03)
     #*******************backward starting*******************
@@ -174,16 +157,15 @@ def check_attention_layer(rank, model_cls, world_size, port):
     #*******************strategy selected*******************
     if rank == 0:
         print("*******************strategy selected*******************")
-        strategies_list = solver.last_s_val
         nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
         computation_cost = 0
         communication_cost = 0
         memory_cost = 0
         for index, node in enumerate(nodes):
-            print(node.name, node.strategies_vector[strategies_list[index]].name)
-            computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
-            communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
-            node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
+            print(node.name, node.strategies_vector[solution[index]].name)
+            computation_cost += node.strategies_vector[solution[index]].compute_cost.total
+            communication_cost += node.strategies_vector[solution[index]].communication_cost.total
+            node_memory_cost = node.strategies_vector[solution[index]].memory_cost.total
             if isinstance(node_memory_cost, tuple):
                 node_memory_cost = node_memory_cost[0]
             memory_cost += node_memory_cost.activation + node_memory_cost.parameter
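
Where a test still needs the intermediate artifacts (as the GPT attention test above does, to report per-node strategy costs), the commit uses the staged variant of the same API. A sketch assembled from the call sites in this file, with the tracing setup unchanged; treat argument order as illustrative:

strategies_constructor = build_strategy_constructor(graph, device_mesh)
# memory_budget=-1 appears to disable the memory constraint
solution = solve_solution(gm, strategies_constructor, memory_budget=-1)
gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
# the wrapper owns the spec dicts, so forward becomes gm(*input_sample)
gm = ModuleWrapper(gm, *sharding_spec_dicts)

With the Solver object gone, the reporting loop indexes strategies with solution[index] rather than solver.last_s_val.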

View File

@@ -57,7 +57,7 @@ def mem_test_for_node_strategy(rank: int,
         output_key]
     gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
-        gm, solution, device_mesh)
+        gm, solution, device_mesh, strategies_constructor)
     gm = runtime_apply_pass(gm)
     gm.recompile()
     gm: GraphModule

View File

@@ -1,270 +0,0 @@
-import copy
-from copy import deepcopy
-from functools import partial
-import pytest
-import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
-from torch.fx import GraphModule
-from torchvision.models import resnet34, resnet50
-from colossalai import device
-from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
-from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
-from colossalai.auto_parallel.tensor_shard.constants import *
-from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
-from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
-from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
-from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.initialize import launch
-from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, assert_close_loose, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
-seed = 128
-cudnn_benchmark = False
-cudnn_deterministic = True
-def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
-    """3x3 convolution with padding"""
-    return nn.Conv2d(
-        in_planes,
-        out_planes,
-        kernel_size=3,
-        stride=stride,
-        padding=dilation,
-        groups=groups,
-        bias=False,
-        dilation=dilation,
-    )
-def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
-    """1x1 convolution"""
-    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
-class Bottleneck(nn.Module):
-    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
-    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
-    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
-    # This variant is also known as ResNet V1.5 and improves accuracy according to
-    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
-    expansion: int = 4
-    def __init__(
-        self,
-        inplanes: int,
-        planes: int,
-        stride: int = 1,
-        downsample=None,
-        groups: int = 1,
-        base_width: int = 64,
-        dilation: int = 1,
-        norm_layer=None,
-    ) -> None:
-        super().__init__()
-        if norm_layer is None:
-            norm_layer = nn.BatchNorm2d
-        width = int(planes * (base_width / 64.0)) * groups
-        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
-        self.conv1 = conv1x1(inplanes, width)
-        self.bn1 = norm_layer(width)
-        self.conv2 = conv3x3(width, width, stride, groups, dilation)
-        self.bn2 = norm_layer(width)
-        self.conv3 = conv1x1(width, planes * self.expansion)
-        self.bn3 = norm_layer(planes * self.expansion)
-        self.relu = nn.ReLU(inplace=True)
-        self.downsample = downsample
-        self.stride = stride
-    def forward(self, x):
-        identity = x
-        out = self.conv1(x)
-        out = self.bn1(out)
-        out = self.relu(out)
-        out = self.conv2(out)
-        out = self.bn2(out)
-        out = self.relu(out)
-        out = self.conv3(out)
-        out = self.bn3(out)
-        if self.downsample is not None:
-            identity = self.downsample(x)
-        out = self.relu(out)
-        return out
-def check_apply_bottleneck(rank, world_size, port):
-    disable_existing_loggers()
-    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    input = torch.rand(4, 4, 4, 4).cuda()
-    physical_mesh_id = torch.arange(0, 4)
-    mesh_shape = (2, 2)
-    # [[0, 1]
-    # [2, 3]]
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    tracer = ColoTracer()
-    model = Bottleneck(4, 4, 1, norm_layer=torch.nn.modules.batchnorm.BatchNorm2d).cuda()
-    test_model = copy.deepcopy(model)
-    test_input = copy.deepcopy(input)
-    # graph():
-    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
-    #     %conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
-    #     %bn1 : [#users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
-    #     %relu : [#users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
-    #     %conv2 : [#users=1] = call_module[target=conv2](args = (%relu,), kwargs = {})
-    #     %bn2 : [#users=1] = call_module[target=bn2](args = (%conv2,), kwargs = {})
-    #     %relu_1 : [#users=1] = call_module[target=relu](args = (%bn2,), kwargs = {})
-    #     %conv3 : [#users=1] = call_module[target=conv3](args = (%relu_1,), kwargs = {})
-    #     %bn3 : [#users=1] = call_module[target=bn3](args = (%conv3,), kwargs = {})
-    #     %relu_2 : [#users=1] = call_module[target=relu](args = (%bn3,), kwargs = {})
-    #     return relu_2
-    input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
-    graph = tracer.trace(root=model, meta_args=input_sample)
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
-    solver_options = SolverOptions()
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-    strategies_constructor.build_strategies_and_cost()
-    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
-    cost_graph.simplify_graph()
-    graph_analyser = GraphAnalyser(gm)
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
-    ret = solver.call_solver_serialized_args()
-    solution = list(ret[0])
-    print(solution)
-    for index, node in enumerate(graph.nodes):
-        print(node.name, node.strategies_vector[solution[index]].name)
-    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
-    gm = runtime_apply_pass(gm)
-    gm.recompile()
-    nodes = [node for node in gm.graph.nodes]
-    # TODO: wrap the gm to avoid the influence of the user training code
-    cuda_rng_state = torch.cuda.get_rng_state()
-    origin_output = test_model(test_input)
-    torch.cuda.set_rng_state(cuda_rng_state)
-    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
-    assert output.shape == origin_output.shape
-    assert_close(output, origin_output, rtol=1e-03, atol=1e-05)
-    print("*******************backward starting*******************")
-    cuda_rng_state = torch.cuda.get_rng_state()
-    output.sum().backward()
-    torch.cuda.set_rng_state(cuda_rng_state)
-    origin_output.sum().backward()
-    if rank == 0:
-        print(
-            f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 0, 4)).abs().sum()}"
-        )
-        print(
-            f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum()}"
-        )
-        print(
-            f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
-        )
-        print(
-            f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
-        )
-        print(
-            f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 0, 1)).abs().sum()}"
-        )
-        print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
-        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
-        assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 0, 2).sum())
-        assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
-    if rank == 1:
-        print(
-            f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 4, 4)).abs().sum()}"
-        )
-        print(
-            f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum()}"
-        )
-        print(
-            f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
-        )
-        print(
-            f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
-        )
-        print(
-            f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 1, 1)).abs().sum()}"
-        )
-        print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
-        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
-        assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 2, 2).sum())
-        assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
-    if rank == 2:
-        print(
-            f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 8, 4)).abs().sum()}"
-        )
-        print(
-            f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum()}"
-        )
-        print(
-            f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
-        )
-        print(
-            f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
-        )
-        print(
-            f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 2, 1)).abs().sum()}"
-        )
-        print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
-        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
-        assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 0, 2).sum())
-        assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
-    if rank == 3:
-        print(
-            f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 12, 4)).abs().sum()}"
-        )
-        print(
-            f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum()}"
-        )
-        print(
-            f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
-        )
-        print(
-            f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
-        )
-        print(
-            f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 3, 1)).abs().sum()}"
-        )
-        print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
-        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
-        assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 2, 2).sum())
-        assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
-@run_on_environment_flag(name='AUTO_PARALLEL')
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_apply():
-    world_size = 4
-    run_func = partial(check_apply_bottleneck, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-if __name__ == '__main__':
-    test_apply()

View File

@@ -5,19 +5,9 @@ import pytest
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from torch.fx import GraphModule
-from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
-from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.testing import assert_close, rerun_if_address_is_in_use
@@ -41,41 +31,22 @@ def check_apply(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     input = torch.rand(4, 4, 4, 4).cuda()
+    test_input = copy.deepcopy(input)
+    # graph():
+    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
+    #     %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
+    #     return conv
+    model = ConvModel(4, 4).cuda()
+    test_model = copy.deepcopy(model)
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     # [[0, 1]
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+    meta_args = {'x': torch.rand(4, 4, 4, 4).to('meta')}
+    gm = initialize_model(model, meta_args, device_mesh)
-    tracer = ColoTracer()
-    model = ConvModel(4, 4).cuda()
-    test_model = copy.deepcopy(model)
-    test_input = copy.deepcopy(input)
-    input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
-    # graph():
-    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
-    #     %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
-    #     return conv
-    graph = tracer.trace(root=model, meta_args=input_sample)
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    gm.recompile()
-    solver_options = SolverOptions()
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-    strategies_constructor.build_strategies_and_cost()
-    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
-    cost_graph.simplify_graph()
-    graph_analyser = GraphAnalyser(gm)
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
-    ret = solver.call_solver_serialized_args()
-    solution = list(ret[0])
-    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
-    gm = runtime_apply_pass(gm)
-    gm.recompile()
-    nodes = [node for node in gm.graph.nodes]
-    # TODO: wrap the gm to avoid the influence of the user training code
-    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+    output = gm(input)
     origin_output = test_model(test_input)
     assert output.equal(origin_output)
     origin_loss = origin_output.sum()
@@ -84,13 +55,21 @@ def check_apply(rank, world_size, port):
     origin_loss.backward()
     loss.backward()
-    grad_0 = test_model.conv.weight.grad.narrow(0, 0, 2)
-    grad_1 = test_model.conv.weight.grad.narrow(0, 2, 2)
+    grad_0 = test_model.conv.weight.grad.narrow(0, 0, 1)
+    grad_1 = test_model.conv.weight.grad.narrow(0, 1, 1)
+    grad_2 = test_model.conv.weight.grad.narrow(0, 2, 1)
+    grad_3 = test_model.conv.weight.grad.narrow(0, 3, 1)
-    if rank in (0, 1):
-        assert_close(gm.conv.weight.grad.data, grad_0.data)
-    elif rank in (2, 3):
-        assert_close(gm.conv.weight.grad.data, grad_1.data)
+    if rank == 0:
+        assert_close(gm.module.conv.weight.grad.data, grad_0.data)
+    elif rank == 1:
+        assert_close(gm.module.conv.weight.grad.data, grad_1.data)
+    elif rank == 2:
+        assert_close(gm.module.conv.weight.grad.data, grad_2.data)
+    elif rank == 3:
+        assert_close(gm.module.conv.weight.grad.data, grad_3.data)
     else:
         raise ValueError(f'rank {rank} does not exist.')
 # skip this test due to pulp not installed in CI environment