mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] shard param and buffer as expected (#1753)
* [autoparallel] shard param and buffer as expected
* fix unit test issue

parent cdb7d5e7d2
commit 980ed21723

@@ -1,4 +1,5 @@
import builtins
import copy
import operator
from ast import NodeTransformer
from copy import deepcopy
@@ -11,34 +12,13 @@ from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType, OperationDataType
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.passes.split_module import split_module
from colossalai.tensor.comm_spec import CommSpec, _all_reduce
from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec, _all_reduce, pattern_to_func_dict
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec

shape_consistency_manager = ShapeConsistencyManager()


class ConsistencyApply(torch.autograd.Function):

    @staticmethod
    def forward(ctx, node, origin_sharding_spec, target_sharding_spec):
        ctx.origin_sharding_spec = origin_sharding_spec
        ctx.target_sharding_spec = target_sharding_spec
        return shape_consistency_manager.apply_for_autoparallel_runtime(node, ctx.origin_sharding_spec,
                                                                        ctx.target_sharding_spec)

    @staticmethod
    def backward(ctx, node_grad):
        return shape_consistency_manager.apply_for_autoparallel_runtime(
            node_grad, ctx.target_sharding_spec, ctx.origin_sharding_spec), None, None, None, None


def runtime_apply_for_leaf_node(node, origin_dict, input_dict, node_index, user_node_index):
    origin_sharding_spec = origin_dict[node_index]
    target_sharding_spec = input_dict[node_index][user_node_index]
    return ConsistencyApply.apply(node, origin_sharding_spec, target_sharding_spec)


def runtime_apply(node, origin_dict, input_dict, node_index, user_node_index):
    origin_sharding_spec = origin_dict[node_index]
    target_sharding_spec = input_dict[node_index][user_node_index]

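ConsistencyApply keeps the runtime layout conversion differentiable: forward reshards an activation from its origin sharding spec to the target spec, and backward reshards the incoming gradient in the opposite direction. Below is a minimal, self-contained sketch of the same torch.autograd.Function pattern with a plain transpose standing in for the shape-consistency conversion; ToyConsistency and the spec strings are illustrative names, not part of this commit.

import torch

class ToyConsistency(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tensor, src_spec, tgt_spec):
        ctx.src_spec, ctx.tgt_spec = src_spec, tgt_spec
        # stand-in for resharding `tensor` from src_spec to tgt_spec
        return tensor.t().contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        # reshard the gradient back from tgt_spec to src_spec; one None per non-tensor input
        return grad_output.t().contiguous(), None, None

x = torch.randn(2, 3, requires_grad=True)
y = ToyConsistency.apply(x, "origin_spec", "target_spec")
y.sum().backward()
assert x.grad.shape == x.shape
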
@@ -53,7 +33,7 @@ def runtime_comm_spec_apply(tensor, comm_actions_dict, node_index, op_data):
    else:
        origin_sharding_spec = comm_action.comm_spec['src_spec']
        tgt_sharding_spec = comm_action.comm_spec['tgt_spec']
        rst = ConsistencyApply.apply(tensor, origin_sharding_spec, tgt_sharding_spec)
        rst = shape_consistency_manager.apply_for_autoparallel_runtime(tensor, origin_sharding_spec, tgt_sharding_spec)
    return rst
@@ -75,12 +55,17 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
        if node.op == 'call_module':
            target_module = node.graph.owning_module.get_submodule(node.target)
            for name, param in target_module.named_parameters():
                origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
                setattr(param, 'sharding_spec', origin_sharding_spec)
                target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                shape_consistency_manager.apply(param, target_sharding_spec)
                if target_sharding_spec.dim_partition_dict != {}:
                    origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
                    setattr(param, 'sharding_spec', origin_sharding_spec)
                    param_sharded = torch.nn.Parameter(
                        shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
                                                                                 target_sharding_spec).detach().clone())
                else:
                    param_sharded = param
                setattr(target_module, name, param_sharded)
                comm_actions = node.best_strategy.communication_actions

                for operation_data, comm_action in comm_actions.items():
                    comm_spec_to_use = comm_action.comm_spec
                    if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
@@ -88,13 +73,18 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
                        def hook_fn(grad):
                            _all_reduce(grad, comm_spec_to_use)

                        param.register_hook(hook_fn)
                        param_sharded.register_hook(hook_fn)

            sharded_buffer_dict = {}
            for name, buffer in target_module.named_buffers():
                origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
                setattr(buffer, 'sharding_spec', origin_sharding_spec)
                target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                shape_consistency_manager.apply(buffer, target_sharding_spec)
                buffer_sharded = shape_consistency_manager.apply(buffer, target_sharding_spec)
                sharded_buffer_dict[name] = buffer_sharded

            for name, buffer_sharded in sharded_buffer_dict.items():
                setattr(target_module, name, buffer_sharded.detach().clone())

    # the dict to get input sharding specs of user node
    sharding_spec_convert_dict = {}

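The two solution_annotatation_pass hunks above are the core of this change: instead of resharding a parameter in place, the pass now builds a sharded copy, registers it back on the module with setattr, and attaches the gradient all-reduce hook to the sharded parameter (buffers get the same treatment through sharded_buffer_dict). A self-contained sketch of that replace-with-shard idea, with torch.chunk standing in for the shape-consistency resharding; shard_module_params and its arguments are hypothetical names, not part of this commit.

import torch
import torch.nn as nn

def shard_module_params(module: nn.Module, shard_dim: int, num_shards: int, shard_idx: int):
    # only replace the parameter when a non-trivial partition is requested,
    # mirroring the `dim_partition_dict != {}` check in the pass above
    for name, param in list(module.named_parameters(recurse=False)):
        if num_shards > 1:
            local_shard = torch.chunk(param.data, num_shards, dim=shard_dim)[shard_idx]
            param_sharded = nn.Parameter(local_shard.detach().clone())
        else:
            param_sharded = param
        setattr(module, name, param_sharded)

linear = nn.Linear(8, 8, bias=False)
shard_module_params(linear, shard_dim=0, num_shards=2, shard_idx=0)
print(linear.weight.shape)    # torch.Size([4, 8])
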
@@ -157,19 +147,11 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):

        for user_node in node.strategies_vector.successor_nodes:
            user_node_index = user_node.strategies_vector.predecessor_nodes.index(node)
            if user_node.op != "output":
                with mod_graph.inserting_before(user_node):
                    shape_consistency_node = mod_graph.create_node('call_function',
                                                                   runtime_apply,
                                                                   args=(node, origin_dict_node, input_dict_node,
                                                                         node_to_index_dict[node], user_node_index))
            else:
                # we need to call an autograd.Function for leaf node
                with mod_graph.inserting_before(user_node):
                    shape_consistency_node = mod_graph.create_node('call_function',
                                                                   runtime_apply_for_leaf_node,
                                                                   args=(node, origin_dict_node, input_dict_node,
                                                                         node_to_index_dict[node], user_node_index))
            with mod_graph.inserting_before(user_node):
                shape_consistency_node = mod_graph.create_node('call_function',
                                                               runtime_apply,
                                                               args=(node, origin_dict_node, input_dict_node,
                                                                     node_to_index_dict[node], user_node_index))

            origin_index_args = user_node.args.index(node)
            new_args = list(user_node.args)
@@ -179,21 +161,29 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
        comm_actions = node.best_strategy.communication_actions
        for op_data, comm_action in comm_actions.items():
            comm_object = node.args[comm_action.arg_index]
            if op_data.type == OperationDataType.ARG:
                if comm_action.comm_type == CommType.BEFORE:
                    with mod_graph.inserting_before(node):
                        comm_spec_apply_node = mod_graph.create_node('call_function',
                                                                     runtime_comm_spec_apply,
                                                                     args=(comm_object, comm_actions_dict_node,
                                                                           node_to_index_dict[node], op_data.name))
                elif comm_action.comm_type == CommType.AFTER:
                    with mod_graph.inserting_after(node):
                        comm_spec_apply_node = mod_graph.create_node('call_function',
                                                                     runtime_comm_spec_apply,
                                                                     args=(comm_object, comm_actions_dict_node,
                                                                           node_to_index_dict[node], op_data.name))
            if op_data.type == OperationDataType.PARAM:
                continue
            if comm_action.comm_type == CommType.BEFORE:
                with mod_graph.inserting_before(node):
                    comm_spec_apply_node = mod_graph.create_node('call_function',
                                                                 runtime_comm_spec_apply,
                                                                 args=(comm_object, comm_actions_dict_node,
                                                                       node_to_index_dict[node], op_data.name))
                new_args = list(node.args)
                new_args[comm_action.arg_index] = comm_spec_apply_node
                node.args = new_args
            elif comm_action.comm_type == CommType.AFTER:
                with mod_graph.inserting_after(node):
                    comm_spec_apply_node = mod_graph.create_node('call_function',
                                                                 runtime_comm_spec_apply,
                                                                 args=(node, comm_actions_dict_node,
                                                                       node_to_index_dict[node], op_data.name))
                user_list = list(node.users.keys())
                for user in user_list:
                    if user == comm_spec_apply_node:
                        continue
                    new_args = list(user.args)
                    new_args[new_args.index(node)] = comm_spec_apply_node
                    user.args = tuple(new_args)
            # TODO: consider other OperationDataType, such as OperationDataType.OUTPUT
            new_args = list(node.args)
            new_args[comm_action.arg_index] = comm_spec_apply_node
            node.args = new_args
    return gm

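shape_consistency_pass performs its work as torch.fx graph surgery: it creates call_function nodes with inserting_before / inserting_after and then rewires the consumer's args to point at the new node. Below is a self-contained sketch of that splice-and-rewire pattern on a toy traced module; Double and add_one are placeholders, not part of this commit.

import torch
from torch.fx import symbolic_trace

def add_one(x):
    # placeholder for a runtime conversion such as runtime_apply / runtime_comm_spec_apply
    return x + 1

class Double(torch.nn.Module):

    def forward(self, x):
        return x * 2

gm = symbolic_trace(Double())
graph = gm.graph
for node in list(graph.nodes):
    if node.op == 'output':
        producer = node.args[0]
        # splice the conversion node in right before the consumer ...
        with graph.inserting_before(node):
            converted = graph.create_node('call_function', add_one, args=(producer,))
        # ... then rewire the consumer to read from the new node
        node.args = (converted,)
graph.lint()
gm.recompile()
print(gm(torch.tensor(3.0)))    # tensor(7.)
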
@@ -345,9 +345,9 @@ class CommSpec:
            tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
        '''
        if self.comm_pattern in pattern_to_func_dict:
            tensor.data = pattern_to_func_dict[self.comm_pattern](tensor, self)
            tensor = pattern_to_func_dict[self.comm_pattern](tensor, self)
        else:
            tensor.data = tensor
            tensor = tensor
        return tensor

@@ -511,13 +511,13 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
        '''
        _, comm_action_sequence, _ = self.shape_consistency(tensor_with_sharding_spec.sharding_spec, target_spec)
        for comm_spec in comm_action_sequence:
            comm_spec.covert_spec_to_action(tensor_with_sharding_spec)
            tensor_with_sharding_spec = comm_spec.covert_spec_to_action(tensor_with_sharding_spec)
        tensor_with_sharding_spec.sharding_spec = target_spec
        return tensor_with_sharding_spec

    def apply_for_autoparallel_runtime(self, tensor, source_spec, target_spec):
        _, comm_action_sequence, _ = self.shape_consistency(source_spec, target_spec)
        for comm_spec in comm_action_sequence:
            comm_spec.covert_spec_to_action(tensor)
            tensor = comm_spec.covert_spec_to_action(tensor)
        tensor.sharding_spec = target_spec
        return tensor

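Both hunks above switch from mutating tensor.data in place to returning the communicated tensor, and the updated tests rebind the result accordingly (for example, tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)). A tiny illustration of why the rebinding matters once a collective changes the local shape; gather_like is a hypothetical stand-in for an all-gather, not a ColossalAI API.

import torch

def gather_like(t: torch.Tensor) -> torch.Tensor:
    # hypothetical stand-in for an all-gather: the result has a different local shape
    return torch.cat([t, t], dim=0)

t = torch.ones(2, 4)
t = gather_like(t)    # rebind to the returned tensor rather than overwriting t.data in place
assert t.shape == torch.Size([4, 4])
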
@@ -1,28 +1,32 @@
import copy
from copy import deepcopy
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai import device
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import shape_consistency_pass, solution_annotatation_pass
from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from copy import deepcopy
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from torch.fx import GraphModule
from torchvision.models import resnet34, resnet50

from colossalai import device
from colossalai.auto_parallel.tensor_shard.constants import *
from colossalai.testing import assert_close_loose, assert_close
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (
    shape_consistency_pass,
    solution_annotatation_pass,
)
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, assert_close_loose, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port

seed = 128
cudnn_benchmark = False
@@ -108,16 +112,17 @@ class Bottleneck(nn.Module):
def check_apply_bottleneck(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    input = torch.rand(256, 64, 64, 64).cuda()
    input = torch.rand(4, 4, 4, 4).cuda()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    # [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=False)
    entire_shape = torch.Size((4, 4, 8, 8))
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    tracer = ColoTracer()
    model = Bottleneck(64, 64, 1, norm_layer=torch.nn.modules.batchnorm.BatchNorm2d).cuda()
    model = Bottleneck(4, 4, 1, norm_layer=torch.nn.modules.batchnorm.BatchNorm2d).cuda()
    test_model = copy.deepcopy(model)
    test_input = copy.deepcopy(input)
    # graph():
    # %x : torch.Tensor [#users=1] = placeholder[target=x]
    # %conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
@@ -130,9 +135,8 @@ def check_apply_bottleneck(rank, world_size, port):
    # %bn3 : [#users=1] = call_module[target=bn3](args = (%conv3,), kwargs = {})
    # %relu_2 : [#users=1] = call_module[target=relu](args = (%bn3,), kwargs = {})
    # return relu_2
    input_sample = {'x': torch.rand(256, 64, 224, 224).to('meta')}
    cuda_rng_state = torch.cuda.get_rng_state()
    origin_output = model(input)
    input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}

    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
@@ -147,16 +151,42 @@ def check_apply_bottleneck(rank, world_size, port):
    ret = solver.call_solver_serialized_args()
    solution = list(ret[0])
    print(solution)
    device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()
    sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
    for index, node in enumerate(graph.nodes):
        print(node.name, node.strategies_vector[solution[index]].name)
    sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh)
    shape_consistency_pass(gm)
    gm.recompile()
    nodes = [node for node in gm.graph.nodes]
    # TODO: wrap the gm to avoid the influence of the user training code
    cuda_rng_state = torch.cuda.get_rng_state()
    origin_output = test_model(test_input)
    torch.cuda.set_rng_state(cuda_rng_state)
    output = gm(input, sharding_spec_dict, origin_spec_dict)
    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)

    assert output.shape == origin_output.shape
    assert output.equal(origin_output)
    assert_close(output, origin_output)
    print("*******************backward starting*******************")
    cuda_rng_state = torch.cuda.get_rng_state()
    output.sum().backward()
    torch.cuda.set_rng_state(cuda_rng_state)
    origin_output.sum().backward()
    if rank == 0:
        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 0, 4)).abs().sum())
        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum())
        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
    if rank == 1:
        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 4, 4)).abs().sum())
        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum())
        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
    if rank == 2:
        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 8, 4)).abs().sum())
        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum())
        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())

    if rank == 3:
        print((gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 12, 4)).abs().sum())
        print((gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum())
        assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())


@run_on_environment_flag(name='AUTO_PARALLEL')

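The new backward checks compare each rank's local gradient against a narrow of the unsharded model's gradient; for the bn3 weight on the 2x2 mesh, rank r is expected to hold the dim-0 slice of length 4 starting at offset 4 * r. A quick self-contained check of that indexing (no distributed setup needed; the tensor below is synthetic).

import torch

full_grad = torch.arange(16.0)    # stands in for the full bn3.weight.grad of length 16
for rank in range(4):
    local = full_grad.narrow(0, 4 * rank, 4)    # the slice rank `rank` should hold
    assert local.shape == torch.Size([4])
    assert local[0].item() == 4.0 * rank
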
@@ -1,17 +1,19 @@
import torch
from functools import partial

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed import ReduceOp

from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CommSpec, CollectiveCommPattern
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port


def check_all_gather(device_mesh, rank):
@@ -37,7 +39,7 @@ def check_all_gather(device_mesh, rank):
                         sharding_spec,
                         gather_dim=1,
                         logical_process_axis=1)
    comm_spec.covert_spec_to_action(sharded_tensor_to_comm)
    sharded_tensor_to_comm = sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm)

    assert sharded_tensor_to_comm.equal(tensor_to_check)
@@ -60,7 +62,7 @@ def check_shard(device_mesh, rank):

    # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
    comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, sharding_spec, shard_dim=1, logical_process_axis=1)
    comm_spec.covert_spec_to_action(tensor_to_shard)
    tensor_to_shard = comm_spec.covert_spec_to_action(tensor_to_shard)

    if rank in (0, 2):
        assert tensor_to_shard.equal(sharded_tensor_to_comm_0)
@@ -110,7 +112,7 @@ def check_all_to_all(device_mesh, rank):
                         gather_dim=0,
                         shard_dim=1,
                         logical_process_axis=0)
    comm_spec.covert_spec_to_action(tensor_to_comm)
    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)
@@ -137,7 +139,7 @@ def check_all_reduce_fwd(device_mesh, rank):
    sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)

    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=0)
    comm_spec.covert_spec_to_action(tensor_to_comm)
    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)
@@ -155,7 +157,7 @@ def check_all_reduce_bwd(device_mesh, rank):
    sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)

    comm_spec = CommSpec(CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, sharding_spec, logical_process_axis=0)
    comm_spec.covert_spec_to_action(tensor_to_comm)
    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)
@@ -178,7 +180,7 @@ def check_all_reduce_in_flatten_device_mesh(device_mesh, rank):

    # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=[0, 1])
    comm_spec.covert_spec_to_action(tensor_to_comm)
    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)

    assert tensor_to_comm.equal(tensor_to_check)

@@ -1,15 +1,16 @@
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp

from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern
from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port


def check_apply(rank, world_size, port):
@@ -63,7 +64,7 @@ def check_apply(rank, world_size, port):
    tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda()

    tensor_to_comm.sharding_spec = sharding_spec_source
    shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
    tensor_to_comm = shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
    assert tensor_to_comm.equal(tensor_to_check)
    assert str(tensor_to_comm.sharding_spec.sharding_sequence) == str(sharding_spec_target.sharding_sequence)