mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] shard param and buffer as expected (#1753)
* [autoparallel] shard param and buffer as expected
* fix unit test issue

branch pull/1751/head
parent cdb7d5e7d2
commit 980ed21723
@@ -1,4 +1,5 @@
 import builtins
+import copy
 import operator
 from ast import NodeTransformer
 from copy import deepcopy
@@ -11,34 +12,13 @@ from torch.fx.node import Node
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType, OperationDataType
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.passes.split_module import split_module
-from colossalai.tensor.comm_spec import CommSpec, _all_reduce
+from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec, _all_reduce, pattern_to_func_dict
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec

 shape_consistency_manager = ShapeConsistencyManager()


-class ConsistencyApply(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, node, origin_sharding_spec, target_sharding_spec):
-        ctx.origin_sharding_spec = origin_sharding_spec
-        ctx.target_sharding_spec = target_sharding_spec
-        return shape_consistency_manager.apply_for_autoparallel_runtime(node, ctx.origin_sharding_spec,
-                                                                        ctx.target_sharding_spec)
-
-    @staticmethod
-    def backward(ctx, node_grad):
-        return shape_consistency_manager.apply_for_autoparallel_runtime(
-            node_grad, ctx.target_sharding_spec, ctx.origin_sharding_spec), None, None, None, None
-
-
-def runtime_apply_for_leaf_node(node, origin_dict, input_dict, node_index, user_node_index):
-    origin_sharding_spec = origin_dict[node_index]
-    target_sharding_spec = input_dict[node_index][user_node_index]
-    return ConsistencyApply.apply(node, origin_sharding_spec, target_sharding_spec)
-
-
 def runtime_apply(node, origin_dict, input_dict, node_index, user_node_index):
     origin_sharding_spec = origin_dict[node_index]
     target_sharding_spec = input_dict[node_index][user_node_index]
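The deleted `ConsistencyApply` wrapper made a layout conversion differentiable by converting the activation in `forward` and converting the gradient back to the origin layout in `backward`. With the conversion path itself now returning tensors that stay on the autograd tape (see the `covert_spec_to_action` and `apply_for_autoparallel_runtime` hunks further down), the dedicated wrapper becomes redundant. A minimal, self-contained sketch of that wrapper pattern, using `permute` as a stand-in for a real cross-device layout conversion (the class and helper names are illustrative, not ColossalAI APIs):

```python
import torch

class LayoutConvert(torch.autograd.Function):
    """Toy stand-in for a shape-consistency conversion: forward moves a tensor
    to the target layout, backward moves the gradient back to the origin layout."""

    @staticmethod
    def forward(ctx, tensor, dims):
        ctx.dims = dims
        return tensor.permute(dims).contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        # invert the permutation so the gradient matches the original layout
        inverse = [0] * len(ctx.dims)
        for i, d in enumerate(ctx.dims):
            inverse[d] = i
        return grad_output.permute(inverse).contiguous(), None

x = torch.randn(2, 3, requires_grad=True)
y = LayoutConvert.apply(x, (1, 0))   # "converted" activation
y.sum().backward()
assert x.grad.shape == x.shape
```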
@@ -53,7 +33,7 @@ def runtime_comm_spec_apply(tensor, comm_actions_dict, node_index, op_data):
     else:
         origin_sharding_spec = comm_action.comm_spec['src_spec']
         tgt_sharding_spec = comm_action.comm_spec['tgt_spec']
-        rst = ConsistencyApply.apply(tensor, origin_sharding_spec, tgt_sharding_spec)
+        rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec)
     return rst

@@ -75,12 +55,17 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
         if node.op == 'call_module':
             target_module = node.graph.owning_module.get_submodule(node.target)
             for name, param in target_module.named_parameters():
-                origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
-                setattr(param, 'sharding_spec', origin_sharding_spec)
                 target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
-                shape_consistency_manager.apply(param, target_sharding_spec)
+                if target_sharding_spec.dim_partition_dict != {}:
+                    origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
+                    setattr(param, 'sharding_spec', origin_sharding_spec)
+                    param_sharded = torch.nn.Parameter(
+                        shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
+                                                                                 target_sharding_spec).detach().clone())
+                else:
+                    param_sharded = param
+                setattr(target_module, name, param_sharded)
                 comm_actions = node.best_strategy.communication_actions

                 for operation_data, comm_action in comm_actions.items():
                     comm_spec_to_use = comm_action.comm_spec
                     if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
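The new branch converts the parameter functionally and re-registers the result on the submodule, so after the pass each rank's module physically holds only its local shard. A rough sketch of that re-registration pattern in plain PyTorch, with `torch.chunk` standing in for the shape-consistency conversion (the `shard_param_` helper and its `dim`/`rank`/`world_size` arguments are hypothetical):

```python
import torch
import torch.nn as nn

def shard_param_(module: nn.Module, name: str, dim: int, rank: int, world_size: int):
    """Replace module.<name> with its local shard along `dim` (illustrative only).

    The real pass derives the shard via the shape-consistency manager; here we
    simply chunk the tensor. detach().clone() keeps the shard a leaf tensor that
    owns its own storage, mirroring the pattern used in the hunk above.
    """
    param = getattr(module, name)
    local_shard = torch.chunk(param.data, world_size, dim=dim)[rank]
    setattr(module, name, nn.Parameter(local_shard.detach().clone()))

linear = nn.Linear(8, 8, bias=False)
shard_param_(linear, 'weight', dim=0, rank=0, world_size=2)
print(linear.weight.shape)   # torch.Size([4, 8])
```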
@@ -88,13 +73,18 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
                         def hook_fn(grad):
                             _all_reduce(grad, comm_spec_to_use)

-                        param.register_hook(hook_fn)
+                        param_sharded.register_hook(hook_fn)

+            sharded_buffer_dict = {}
             for name, buffer in target_module.named_buffers():
                 origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
                 setattr(buffer, 'sharding_spec', origin_sharding_spec)
                 target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
-                shape_consistency_manager.apply(buffer, target_sharding_spec)
+                buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec)
+                sharded_buffer_dict[name] = buffer_sharded
+
+            for name, buffer_sharded in sharded_buffer_dict.items():
+                setattr(target_module, name, buffer_sharded.detach().clone())

     # the dict to get input sharding specs of user node
     sharding_spec_convert_dict = {}
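Two details in this hunk: the gradient hook must be registered on `param_sharded`, the tensor that actually belongs to the module after sharding, and the sharded buffers are collected into `sharded_buffer_dict` first and re-assigned afterwards, so `named_buffers()` is not mutated while it is being iterated. A small single-process sketch of both patterns (the `all_reduce_grad` placeholder is illustrative; a real hook would launch a collective):

```python
import torch
import torch.nn as nn

def all_reduce_grad(grad: torch.Tensor) -> None:
    # placeholder: the real pass launches an all-reduce over a process group;
    # here it is a no-op for a single process
    pass

model = nn.Linear(4, 4)
sharded_weight = nn.Parameter(model.weight.detach().clone())
model.weight = sharded_weight

# the hook has to live on the tensor that receives the gradient,
# i.e. the (possibly re-created) sharded parameter
sharded_weight.register_hook(lambda grad: all_reduce_grad(grad))

# collect the new buffers first, then re-register them, to avoid
# modifying named_buffers() while iterating over it
bn = nn.BatchNorm1d(4)
sharded_buffers = {name: buf.detach().clone() for name, buf in bn.named_buffers()}
for name, buf in sharded_buffers.items():
    setattr(bn, name, buf)
```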
@@ -157,19 +147,11 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):

         for user_node in node.strategies_vector.successor_nodes:
             user_node_index = user_node.strategies_vector.predecessor_nodes.index(node)
-            if user_node.op != "output":
-                with mod_graph.inserting_before(user_node):
-                    shape_consistency_node = mod_graph.create_node('call_function',
-                                                                   runtime_apply,
-                                                                   args=(node, origin_dict_node, input_dict_node,
-                                                                         node_to_index_dict[node], user_node_index))
-            else:
-                # we need to call an autograd.Function for leaf node
-                with mod_graph.inserting_before(user_node):
-                    shape_consistency_node = mod_graph.create_node('call_function',
-                                                                   runtime_apply_for_leaf_node,
-                                                                   args=(node, origin_dict_node, input_dict_node,
-                                                                         node_to_index_dict[node], user_node_index))
+            with mod_graph.inserting_before(user_node):
+                shape_consistency_node = mod_graph.create_node('call_function',
+                                                               runtime_apply,
+                                                               args=(node, origin_dict_node, input_dict_node,
+                                                                     node_to_index_dict[node], user_node_index))

             origin_index_args = user_node.args.index(node)
             new_args = list(user_node.args)
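With the leaf-node special case gone, every consumer gets a single `runtime_apply` node inserted in front of it and its argument rewired to the new node. A generic torch.fx sketch of that insert-before-and-rewire idiom, with a no-op `convert` placeholder in place of `runtime_apply`:

```python
import torch
from torch.fx import symbolic_trace

def convert(x):
    # placeholder for runtime_apply: would convert x between sharding specs
    return x

class Model(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)

gm = symbolic_trace(Model())
graph = gm.graph

for node in list(graph.nodes):
    if node.op != 'call_function' or node.target is not torch.relu:
        continue
    producer = node.args[0]
    # insert a conversion node right before the consumer ...
    with graph.inserting_before(node):
        conv_node = graph.create_node('call_function', convert, args=(producer,))
    # ... and rewire the consumer's argument to use it
    new_args = list(node.args)
    new_args[new_args.index(producer)] = conv_node
    node.args = tuple(new_args)

graph.lint()
gm.recompile()
print(gm(torch.randn(2)))
```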
@@ -179,21 +161,29 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
         comm_actions = node.best_strategy.communication_actions
         for op_data, comm_action in comm_actions.items():
             comm_object = node.args[comm_action.arg_index]
-            if op_data.type == OperationDataType.ARG:
-                if comm_action.comm_type == CommType.BEFORE:
-                    with mod_graph.inserting_before(node):
-                        comm_spec_apply_node = mod_graph.create_node('call_function',
-                                                                     runtime_comm_spec_apply,
-                                                                     args=(comm_object, comm_actions_dict_node,
-                                                                           node_to_index_dict[node], op_data.name))
-                elif comm_action.comm_type == CommType.AFTER:
-                    with mod_graph.inserting_after(node):
-                        comm_spec_apply_node = mod_graph.create_node('call_function',
-                                                                     runtime_comm_spec_apply,
-                                                                     args=(comm_object, comm_actions_dict_node,
-                                                                           node_to_index_dict[node], op_data.name))
+            if op_data.type == OperationDataType.PARAM:
+                continue
+            if comm_action.comm_type == CommType.BEFORE:
+                with mod_graph.inserting_before(node):
+                    comm_spec_apply_node = mod_graph.create_node('call_function',
+                                                                 runtime_comm_spec_apply,
+                                                                 args=(comm_object, comm_actions_dict_node,
+                                                                       node_to_index_dict[node], op_data.name))
+                new_args = list(node.args)
+                new_args[comm_action.arg_index] = comm_spec_apply_node
+                node.args = new_args
+            elif comm_action.comm_type == CommType.AFTER:
+                with mod_graph.inserting_after(node):
+                    comm_spec_apply_node = mod_graph.create_node('call_function',
+                                                                 runtime_comm_spec_apply,
+                                                                 args=(node, comm_actions_dict_node,
+                                                                       node_to_index_dict[node], op_data.name))
+                user_list = list(node.users.keys())
+                for user in user_list:
+                    if user == comm_spec_apply_node:
+                        continue
+                    new_args = list(user.args)
+                    new_args[new_args.index(node)] = comm_spec_apply_node
+                    user.args = tuple(new_args)
             # TODO: consider other OperationDataType, such as OperationDataType.OUTPUT
-            new_args = list(node.args)
-            new_args[comm_action.arg_index] = comm_spec_apply_node
-            node.args = new_args
     return gm
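The `CommType.AFTER` branch now feeds the node itself into the communication call and then points every downstream user at the communication result, skipping the freshly inserted node to avoid a cycle; `CommType.BEFORE` instead patches the node's own argument list. A plain torch.fx sketch of the insert-after-and-redirect idiom (with a no-op `comm_apply` placeholder):

```python
import operator
import torch
from torch.fx import symbolic_trace

def comm_apply(x):
    # placeholder for runtime_comm_spec_apply: would run e.g. an all-reduce
    return x

class Model(torch.nn.Module):
    def forward(self, x):
        y = x * 2
        return y + 1, y - 1

gm = symbolic_trace(Model())
graph = gm.graph

for node in list(graph.nodes):
    if node.op == 'call_function' and node.target is operator.mul:
        with graph.inserting_after(node):
            comm_node = graph.create_node('call_function', comm_apply, args=(node,))
        # redirect every user of the original node to the comm node,
        # skipping the comm node itself to avoid a cycle
        for user in list(node.users.keys()):
            if user is comm_node:
                continue
            new_args = list(user.args)
            new_args[new_args.index(node)] = comm_node
            user.args = tuple(new_args)

graph.lint()
gm.recompile()
print(gm(torch.ones(2)))
```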
@@ -345,9 +345,9 @@ class CommSpec:
             tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
         '''
         if self.comm_pattern in pattern_to_func_dict:
-            tensor.data = pattern_to_func_dict[self.comm_pattern](tensor, self)
+            tensor = pattern_to_func_dict[self.comm_pattern](tensor, self)
         else:
-            tensor.data = tensor
+            tensor = tensor
         return tensor


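Writing the result into `tensor.data` bypasses autograd, so the collective would be invisible to the backward pass; returning a fresh tensor keeps the operation on the tape, which is why the callers in the hunks below now rebind the return value. A minimal sketch of the difference, with an ordinary differentiable op standing in for a collective:

```python
import torch

def scale_inplace_via_data(t: torch.Tensor) -> torch.Tensor:
    t.data = t * 2          # .data assignment is invisible to autograd
    return t

def scale_functional(t: torch.Tensor) -> torch.Tensor:
    return t * 2            # stays on the autograd tape

a = torch.ones(3, requires_grad=True)
b = torch.ones(3, requires_grad=True)

scale_inplace_via_data(a).sum().backward()
scale_functional(b).sum().backward()

print(a.grad)   # tensor([1., 1., 1.])  -- the doubling was never recorded
print(b.grad)   # tensor([2., 2., 2.])
```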
@@ -511,13 +511,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         '''
         _, comm_action_sequence, _ = self.shape_consistency(tensor_with_sharding_spec.sharding_spec, target_spec)
         for comm_spec in comm_action_sequence:
-            comm_spec.covert_spec_to_action(tensor_with_sharding_spec)
+            tensor_with_sharding_spec = comm_spec.covert_spec_to_action(tensor_with_sharding_spec)
         tensor_with_sharding_spec.sharding_spec = target_spec
         return tensor_with_sharding_spec

     def apply_for_autoparallel_runtime(self, tensor, source_spec, target_spec):
         _, comm_action_sequence, _ = self.shape_consistency(source_spec, target_spec)
         for comm_spec in comm_action_sequence:
-            comm_spec.covert_spec_to_action(tensor)
+            tensor = comm_spec.covert_spec_to_action(tensor)
         tensor.sharding_spec = target_spec
         return tensor
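`apply_for_autoparallel_runtime` now threads the tensor through the communication sequence functionally: each step returns a new tensor that feeds the next, and the target spec is attached to the final result. A generic sketch of that chaining pattern, with arbitrary tensor ops standing in for the communication steps:

```python
import torch

def apply_conversion_chain(tensor: torch.Tensor, steps) -> torch.Tensor:
    """Functionally apply a sequence of layout-conversion steps.

    Each step takes a tensor and returns a new tensor; rebinding the local
    name is what makes the chain work once steps no longer mutate in place.
    """
    for step in steps:
        tensor = step(tensor)
    return tensor

steps = [
    lambda t: t.reshape(2, 8),      # stand-in for an all-to-all style reshuffle
    lambda t: t.transpose(0, 1),    # stand-in for a layout change
    lambda t: t.contiguous(),
]
out = apply_conversion_chain(torch.arange(16.0), steps)
print(out.shape)   # torch.Size([8, 2])
```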
@@ -1,28 +1,32 @@
+import copy
+from copy import deepcopy
 from functools import partial

 import pytest
 import torch
 import torch.multiprocessing as mp
-from torch.fx import GraphModule
 import torch.nn as nn
-import pytest
-from colossalai import device
-from colossalai.initialize import launch
-from colossalai.utils import free_port
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.logging import disable_existing_loggers
-from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import shape_consistency_pass, solution_annotatation_pass
-from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
-from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
-from copy import deepcopy
-from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
+from torch.fx import GraphModule
 from torchvision.models import resnet34, resnet50

+from colossalai import device
 from colossalai.auto_parallel.tensor_shard.constants import *
-from colossalai.testing import assert_close_loose, assert_close
+from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
+from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
+from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
+from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (
+    shape_consistency_pass,
+    solution_annotatation_pass,
+)
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import assert_close, assert_close_loose, rerun_if_address_is_in_use
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.utils import free_port

 seed = 128
 cudnn_benchmark = False
@@ -108,16 +112,17 @@ class Bottleneck(nn.Module):
 def check_apply_bottleneck(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    input = torch.rand(256, 64, 64, 64).cuda()
+    input = torch.rand(4, 4, 4, 4).cuda()
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     # [[0, 1]
     #  [2, 3]]
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=False)
-    entire_shape = torch.Size((4, 4, 8, 8))
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

     tracer = ColoTracer()
-    model = Bottleneck(64, 64, 1, norm_layer=torch.nn.modules.batchnorm.BatchNorm2d).cuda()
+    model = Bottleneck(4, 4, 1, norm_layer=torch.nn.modules.batchnorm.BatchNorm2d).cuda()
+    test_model = copy.deepcopy(model)
+    test_input = copy.deepcopy(input)
     # graph():
     #     %x : torch.Tensor [#users=1] = placeholder[target=x]
     #     %conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
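Because the annotation pass now shards the traced module's parameters in place, the test snapshots an untouched copy of the model and the input (`test_model`, `test_input`) to serve as the single-device reference. The general pattern, sketched on a toy module:

```python
import copy
import torch
import torch.nn as nn

model = nn.Linear(4, 4).cuda() if torch.cuda.is_available() else nn.Linear(4, 4)
inp = torch.rand(2, 4, device=model.weight.device)

# snapshot before any pass mutates the module's parameters in place
reference_model = copy.deepcopy(model)
reference_input = copy.deepcopy(inp)

# ... run sharding / graph passes on `model` here ...

torch.testing.assert_close(reference_model(reference_input), model(inp))
```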
@@ -130,9 +135,8 @@ def check_apply_bottleneck(rank, world_size, port):
     # %bn3 : [#users=1] = call_module[target=bn3](args = (%conv3,), kwargs = {})
     # %relu_2 : [#users=1] = call_module[target=relu](args = (%bn3,), kwargs = {})
     # return relu_2
-    input_sample = {'x': torch.rand(256, 64, 224, 224).to('meta')}
-    cuda_rng_state = torch.cuda.get_rng_state()
-    origin_output = model(input)
+    input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
     graph = tracer.trace(root=model, meta_args=input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
@@ -147,16 +151,42 @@ def check_apply_bottleneck(rank, world_size, port):
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
     print(solution)
-    device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()
-    sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
+    for index, node in enumerate(graph.nodes):
+        print(node.name, node.strategies_vector[solution[index]].name)
+    sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh)
     shape_consistency_pass(gm)
     gm.recompile()
     nodes = [node for node in gm.graph.nodes]
     # TODO: wrap the gm to avoid the influence of the user training code
+    cuda_rng_state = torch.cuda.get_rng_state()
+    origin_output = test_model(test_input)
     torch.cuda.set_rng_state(cuda_rng_state)
-    output = gm(input, sharding_spec_dict, origin_spec_dict)
+    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)

     assert output.shape == origin_output.shape
-    assert output.equal(origin_output)
+    assert_close(output, origin_output)
+    print("*******************backward starting*******************")
+    cuda_rng_state = torch.cuda.get_rng_state()
+    output.sum().backward()
+    torch.cuda.set_rng_state(cuda_rng_state)
+    origin_output.sum().backward()
+    if rank == 0:
+        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 0, 4)).abs().sum())
+        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum())
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
+    if rank == 1:
+        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 4, 4)).abs().sum())
+        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum())
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
+    if rank == 2:
+        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 8, 4)).abs().sum())
+        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum())
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
+
+    if rank == 3:
+        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 12, 4)).abs().sum())
+        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum())
+        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())


 @run_on_environment_flag(name='AUTO_PARALLEL')
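The reference forward pass and the transformed module now run under the same CUDA RNG state, so any randomness is reproduced exactly, and each rank then checks its local gradient shard against a `narrow` slice of the unsharded reference gradient. A single-device sketch of the RNG bookkeeping:

```python
import torch

def run_twice_with_same_rng(fn, *args):
    """Run `fn` twice with identical RNG streams and return both results."""
    state = torch.cuda.get_rng_state() if torch.cuda.is_available() else torch.get_rng_state()
    first = fn(*args)
    # rewind the generator so the second run sees the same random numbers
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(state)
    else:
        torch.set_rng_state(state)
    second = fn(*args)
    return first, second

drop = torch.nn.Dropout(p=0.5)
drop.train()
x = torch.ones(8, device='cuda' if torch.cuda.is_available() else 'cpu')
a, b = run_twice_with_same_rng(drop, x)
assert torch.equal(a, b)
```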
@@ -1,17 +1,19 @@
-import torch
 from functools import partial

 import pytest
+import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.distributed import ReduceOp

 from colossalai.core import global_context as gpc
-from colossalai.initialize import launch
-from colossalai.utils import free_port
-from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.shape_consistency import CommSpec, CollectiveCommPattern
+from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port


 def check_all_gather(device_mesh, rank):
@@ -37,7 +39,7 @@ def check_all_gather(device_mesh, rank):
                          sharding_spec,
                          gather_dim=1,
                          logical_process_axis=1)
-    comm_spec.covert_spec_to_action(sharded_tensor_to_comm)
+    sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm)

     assert sharded_tensor_to_comm.equal(tensor_to_check)

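These test updates all follow from the functional `covert_spec_to_action`: callers must capture the returned tensor, otherwise the local variable still points at the pre-communication value. A trivial illustration (the `gather_like` helper is made up for the example):

```python
import torch

def gather_like(t: torch.Tensor) -> torch.Tensor:
    # stand-in for a functional comm op: returns a new tensor, never mutates t
    return torch.cat([t, t], dim=0)

x = torch.ones(2)
gather_like(x)                  # result dropped: x is still shape (2,)
assert x.shape == (2,)

x = gather_like(x)              # rebinding captures the gathered result
assert x.shape == (4,)
```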
@@ -60,7 +62,7 @@ def check_shard(device_mesh, rank):

     # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
     comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, sharding_spec, shard_dim=1, logical_process_axis=1)
-    comm_spec.covert_spec_to_action(tensor_to_shard)
+    tensor_to_shard = comm_spec.covert_spec_to_action(tensor_to_shard)

     if rank in (0, 2):
         assert tensor_to_shard.equal(sharded_tensor_to_comm_0)
@@ -110,7 +112,7 @@ def check_all_to_all(device_mesh, rank):
                          gather_dim=0,
                          shard_dim=1,
                          logical_process_axis=0)
-    comm_spec.covert_spec_to_action(tensor_to_comm)
+    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

     assert tensor_to_comm.equal(tensor_to_check)

@@ -137,7 +139,7 @@ def check_all_reduce_fwd(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)

     comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=0)
-    comm_spec.covert_spec_to_action(tensor_to_comm)
+    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

     assert tensor_to_comm.equal(tensor_to_check)

@@ -155,7 +157,7 @@ def check_all_reduce_bwd(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)

     comm_spec = CommSpec(CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, sharding_spec, logical_process_axis=0)
-    comm_spec.covert_spec_to_action(tensor_to_comm)
+    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

     assert tensor_to_comm.equal(tensor_to_check)

@@ -178,7 +180,7 @@ def check_all_reduce_in_flatten_device_mesh(device_mesh, rank):

     # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
     comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=[0, 1])
-    comm_spec.covert_spec_to_action(tensor_to_comm)
+    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

     assert tensor_to_comm.equal(tensor_to_check)

@@ -1,15 +1,16 @@
 from functools import partial

 import pytest
 import torch
 import torch.multiprocessing as mp

-from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
-from colossalai.utils import free_port
-from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.logging import disable_existing_loggers
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port


 def check_apply(rank, world_size, port):
@@ -63,7 +64,7 @@ def check_apply(rank, world_size, port):
     tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda()

     tensor_to_comm.sharding_spec = sharding_spec_source
-    shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
+    tensor_to_comm = shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
     assert tensor_to_comm.equal(tensor_to_check)
     assert str(tensor_to_comm.sharding_spec.sharding_sequence) == str(sharding_spec_target.sharding_sequence)
