mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] add runtime pass and numerical test for view handler (#2018)
parent
bb6245612d
commit
ea0f6b8df9
|
@ -37,6 +37,30 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
|
|||
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
|
||||
str(node))
|
||||
|
||||
# experimental pass for torch.Tensor.view
|
||||
# Arguments of view op will be divided in the sharded dimensions.
|
||||
for node in nodes:
|
||||
if node.op == 'call_method' and getattr(node.args[0]._meta_data.__class__, node.target) in (torch.Tensor.view,):
|
||||
output_dim_partition_dict = node.sharding_spec.dim_partition_dict
|
||||
device_mesh = node.sharding_spec.device_mesh
|
||||
new_args = []
|
||||
for arg in node.args:
|
||||
if isinstance(arg, Node):
|
||||
if isinstance(arg._meta_data, int):
|
||||
new_args.append(arg._meta_data)
|
||||
else:
|
||||
new_args.append(arg)
|
||||
else:
|
||||
assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.'
|
||||
new_args.append(arg)
|
||||
|
||||
for dim, shard_dims in output_dim_partition_dict.items():
|
||||
total_shard_size = 1
|
||||
for shard_dim in shard_dims:
|
||||
total_shard_size *= device_mesh.shape[shard_dim]
|
||||
new_args[dim + 1] //= total_shard_size
|
||||
node.args = tuple(new_args)
|
||||
|
||||
# the dict to get input sharding specs of user node
|
||||
sharding_spec_convert_dict = {}
|
||||
# the dict to record comm actions of nodes
|
||||
|
|
|
@ -103,13 +103,18 @@ class ViewGenerator(FollowingStrategyGenerator):
|
|||
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
|
||||
if len(total_mesh_dim_list) == 1:
|
||||
total_mesh_dim_list = total_mesh_dim_list[0]
|
||||
# the total mesh dim list only has one element, so the shard dim has only one element as well.
|
||||
shard_dim = list(dim_partition_dict_for_input.keys())[0]
|
||||
input_comm_action = self.get_communication_action(
|
||||
sharding_spec=sharding_spec_mapping["input"],
|
||||
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
|
||||
logical_process_axis=total_mesh_dim_list,
|
||||
comm_type=CommType.BEFORE,
|
||||
arg_index=0)
|
||||
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
|
||||
# it will gather the input through gather_dim during forward phase.
|
||||
input_comm_action.comm_spec.gather_dim = shard_dim
|
||||
# it will split the input activation grad through shard_dim during backward phase.
|
||||
input_comm_action.comm_spec.shard_dim = shard_dim
|
||||
|
||||
elif len(total_mesh_dim_list) >= 2:
|
||||
source_spec = sharding_spec_mapping["input"]
|
||||
|
|
|
@ -105,6 +105,7 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
|
|||
dim_mapping={0: i},
|
||||
physical_shape=output_op_data.data.shape,
|
||||
inplace=True)
|
||||
strategy_copy.name = f'{strategy.name}_{i}'
|
||||
sharding_strategies.append(strategy_copy)
|
||||
except ShardingNotDivisibleError as e:
|
||||
logger.debug(
|
||||
|
@ -194,7 +195,7 @@ class LinearModuleHandler(ModuleHandler):
|
|||
@operator_registry.register(F.linear)
|
||||
class LinearFunctionHandler(NodeHandler):
|
||||
"""
|
||||
A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
|
||||
A LinearFunctionHandler which deals with the sharding strategies for F.Linear.
|
||||
"""
|
||||
|
||||
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
||||
|
|
|
@ -1,55 +1,130 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.experimental import ViewHandler
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
|
||||
|
||||
|
||||
class ViewModel(nn.Module):
|
||||
class ConvViewModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, tgt_shape):
|
||||
super().__init__()
|
||||
self.tgt_shape = tgt_shape
|
||||
|
||||
def forward(self, input, other):
|
||||
conv_node = nn.functional.conv2d(input, other)
|
||||
reshape_node = conv_node.view(32, 4, 32, 32, 4)
|
||||
conv_node = nn.functional.conv2d(input, other, bias=None)
|
||||
reshape_node = conv_node.view(*self.tgt_shape)
|
||||
return reshape_node
|
||||
|
||||
|
||||
def test_view_handler():
|
||||
model = ViewModel()
|
||||
tracer = ColoTracer()
|
||||
# graph():
|
||||
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
|
||||
# %other : torch.Tensor [#users=1] = placeholder[target=other]
|
||||
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
|
||||
# %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {})
|
||||
# return view
|
||||
graph = tracer.trace(model,
|
||||
meta_args={
|
||||
"input": torch.rand(8, 8, 66, 66).to('meta'),
|
||||
"other": torch.rand(16, 8, 3, 3).to('meta'),
|
||||
})
|
||||
gm = ColoGraphModule(model, graph)
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
class LinearViewModel(nn.Module):
|
||||
|
||||
def __init__(self, tgt_shape):
|
||||
super().__init__()
|
||||
self.tgt_shape = tgt_shape
|
||||
|
||||
def forward(self, input, other):
|
||||
linear_node = nn.functional.linear(input, other, bias=None)
|
||||
reshape_node = linear_node.view(*self.tgt_shape)
|
||||
return reshape_node
|
||||
|
||||
|
||||
def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
model = model_cls(tgt_shape).cuda()
|
||||
|
||||
if model_cls.__name__ == 'ConvViewModel':
|
||||
input = torch.rand(8, 8, 66, 66).to('cuda')
|
||||
other = torch.rand(16, 8, 3, 3).to('cuda')
|
||||
# index of conv node in computation graph
|
||||
node_index = 2
|
||||
# total number of conv strategies
|
||||
strategy_number = 16
|
||||
if model_cls.__name__ == 'LinearViewModel':
|
||||
input = torch.rand(8, 16, 64, 32).to('cuda')
|
||||
other = torch.rand(64, 32).to('cuda')
|
||||
# index of linear node in computation graph
|
||||
node_index = 2
|
||||
# total number of linear strategies
|
||||
strategy_number = 23
|
||||
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
conv_mod_node = list(graph.nodes)[2]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
|
||||
numerical_test_for_node_strategy(model=model,
|
||||
device_mesh=device_mesh,
|
||||
node_index=node_index,
|
||||
strategy_number=strategy_number,
|
||||
input_args=[input, other],
|
||||
meta_arg_names=['input', 'other'],
|
||||
node_type='following')
|
||||
tracer = ColoTracer()
|
||||
if model_cls.__name__ == 'ConvViewModel':
|
||||
# graph():
|
||||
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
|
||||
# %other : torch.Tensor [#users=1] = placeholder[target=other]
|
||||
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
|
||||
# %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {})
|
||||
# return view
|
||||
graph = tracer.trace(model,
|
||||
meta_args={
|
||||
"input": torch.rand(8, 16, 66, 66).to('meta'),
|
||||
"other": torch.rand(16, 8, 3, 3).to('meta'),
|
||||
})
|
||||
|
||||
if model_cls.__name__ == 'LinearViewModel':
|
||||
# graph():
|
||||
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
|
||||
# %other : torch.Tensor [#users=1] = placeholder[target=other]
|
||||
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
|
||||
# %view : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {})
|
||||
# return view
|
||||
graph = tracer.trace(model,
|
||||
meta_args={
|
||||
"input": torch.rand(8, 16, 64, 32).to('meta'),
|
||||
"other": torch.rand(64, 32).to('meta'),
|
||||
})
|
||||
|
||||
gm = ColoGraphModule(model, graph)
|
||||
|
||||
previous_mod_node = list(graph.nodes)[2]
|
||||
view_node = list(graph.nodes)[3]
|
||||
view_strategies_vector = StrategiesVector(view_node)
|
||||
conv_strategies_vector = StrategiesVector(conv_mod_node)
|
||||
previous_strategies_vector = StrategiesVector(previous_mod_node)
|
||||
|
||||
# build handler
|
||||
conv_handler = ConvFunctionHandler(node=conv_mod_node,
|
||||
device_mesh=device_mesh,
|
||||
strategies_vector=conv_strategies_vector)
|
||||
conv_handler.register_strategy(compute_resharding_cost=False)
|
||||
setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
|
||||
if model_cls.__name__ == 'ConvViewModel':
|
||||
|
||||
conv_handler = ConvFunctionHandler(node=previous_mod_node,
|
||||
device_mesh=device_mesh,
|
||||
strategies_vector=previous_strategies_vector)
|
||||
conv_handler.register_strategy(compute_resharding_cost=False)
|
||||
setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector)
|
||||
|
||||
if model_cls.__name__ == 'LinearViewModel':
|
||||
assert len(previous_strategies_vector) == 0
|
||||
linear_handler = LinearFunctionHandler(node=previous_mod_node,
|
||||
device_mesh=device_mesh,
|
||||
strategies_vector=previous_strategies_vector)
|
||||
linear_handler.register_strategy(compute_resharding_cost=False)
|
||||
setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector)
|
||||
|
||||
view_handler = ViewHandler(node=view_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector)
|
||||
|
||||
view_handler.register_strategy(compute_resharding_cost=False)
|
||||
|
@ -62,7 +137,10 @@ def test_view_handler():
|
|||
# make sure they have valid values
|
||||
assert op_data.data is not None
|
||||
|
||||
assert mapping['input'].name == "conv2d"
|
||||
if model_cls.__name__ == 'ConvViewModel':
|
||||
assert mapping['input'].name == "conv2d"
|
||||
else:
|
||||
assert mapping['input'].name == "linear"
|
||||
assert mapping['input'].data.is_meta
|
||||
assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64])
|
||||
assert mapping['input'].type == OperationDataType.ARG
|
||||
|
@ -70,28 +148,117 @@ def test_view_handler():
|
|||
|
||||
assert mapping['output'].name == "view"
|
||||
assert mapping['output'].data.is_meta
|
||||
assert mapping['output'].data.shape == torch.Size([32, 4, 32, 32, 4])
|
||||
assert mapping['output'].data.shape == torch.Size(tgt_shape)
|
||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||
|
||||
# reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.
|
||||
assert len(view_strategies_vector) == len(conv_strategies_vector)
|
||||
assert len(view_strategies_vector) == len(previous_strategies_vector)
|
||||
strategy_name_list = [strategy.name for strategy in view_strategies_vector]
|
||||
assert '[S0, S1, R, R] -> FULLY REPLICATED_0' in strategy_name_list
|
||||
assert '[S1, S0, R, R] -> FULLY REPLICATED_1' in strategy_name_list
|
||||
assert '[S0, R, R, R] -> [S0, R, R, R, R]_2' in strategy_name_list
|
||||
assert '[S1, R, R, R] -> [S1, R, R, R, R]_3' in strategy_name_list
|
||||
assert '[S0, R, R, R] -> [S0, R, R, R, R]_4' in strategy_name_list
|
||||
assert '[S1, R, R, R] -> [S1, R, R, R, R]_5' in strategy_name_list
|
||||
assert '[R, S1, R, R] -> FULLY REPLICATED_6' in strategy_name_list
|
||||
assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R]_9' in strategy_name_list
|
||||
assert '[R, S0, R, R] -> FULLY REPLICATED_10' in strategy_name_list
|
||||
assert '[R, S1, R, R] -> FULLY REPLICATED_11' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R]_12' in strategy_name_list
|
||||
assert '[S01, R, R, R] -> [S01, R, R, R, R]_13' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list
|
||||
assert '[R, S01, R, R] -> FULLY REPLICATED_15' in strategy_name_list
|
||||
|
||||
if model_cls.__name__ == 'ConvViewModel':
|
||||
|
||||
if tgt_shape == (32, 4, 64, 16, 4):
|
||||
assert '[S0, S1, R, R] -> FULLY REPLICATED_0' in strategy_name_list
|
||||
assert '[S1, S0, R, R] -> FULLY REPLICATED_1' in strategy_name_list
|
||||
assert '[S0, R, R, R] -> [S0, R, R, R, R]_2' in strategy_name_list
|
||||
assert '[S1, R, R, R] -> [S1, R, R, R, R]_3' in strategy_name_list
|
||||
assert '[S0, R, R, R] -> [S0, R, R, R, R]_4' in strategy_name_list
|
||||
assert '[S1, R, R, R] -> [S1, R, R, R, R]_5' in strategy_name_list
|
||||
assert '[R, S1, R, R] -> FULLY REPLICATED_6' in strategy_name_list
|
||||
assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R]_9' in strategy_name_list
|
||||
assert '[R, S0, R, R] -> FULLY REPLICATED_10' in strategy_name_list
|
||||
assert '[R, S1, R, R] -> FULLY REPLICATED_11' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R]_12' in strategy_name_list
|
||||
assert '[S01, R, R, R] -> [S01, R, R, R, R]_13' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list
|
||||
assert '[R, S01, R, R] -> FULLY REPLICATED_15' in strategy_name_list
|
||||
|
||||
if tgt_shape == (8, 4, 4, 64, 16, 4):
|
||||
assert '[S0, S1, R, R] -> [S0, S1, R, R, R, R]_0' in strategy_name_list
|
||||
assert '[S1, S0, R, R] -> [S1, S0, R, R, R, R]_1' in strategy_name_list
|
||||
assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_2' in strategy_name_list
|
||||
assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_3' in strategy_name_list
|
||||
assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_4' in strategy_name_list
|
||||
assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_5' in strategy_name_list
|
||||
assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_6' in strategy_name_list
|
||||
assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_7' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R, R]_8' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R, R]_9' in strategy_name_list
|
||||
assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_10' in strategy_name_list
|
||||
assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_11' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R, R]_12' in strategy_name_list
|
||||
assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_13' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R, R]_14' in strategy_name_list
|
||||
assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_15' in strategy_name_list
|
||||
|
||||
if model_cls.__name__ == 'LinearViewModel':
|
||||
|
||||
if tgt_shape == (32, 4, 64, 16, 4):
|
||||
assert '[S0, R, R, S1] -> [S0, R, R, S1, R]_0' in strategy_name_list
|
||||
assert '[R, S0, R, S1] -> FULLY REPLICATED_1' in strategy_name_list
|
||||
assert '[R, R, S0, S1] -> [R, R, S0, S1, R]_2' in strategy_name_list
|
||||
assert '[S1, R, R, S0] -> [S1, R, R, S0, R]_3' in strategy_name_list
|
||||
assert '[R, S1, R, S0] -> FULLY REPLICATED_4' in strategy_name_list
|
||||
assert '[R, R, S1, S0] -> [R, R, S1, S0, R]_5' in strategy_name_list
|
||||
assert '[S0, R, R, R] -> [S0, R, R, R, R]_6' in strategy_name_list
|
||||
assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list
|
||||
assert '[R, R, S0, R] -> [R, R, S0, R, R]_8' in strategy_name_list
|
||||
assert '[S1, R, R, R] -> [S1, R, R, R, R]_9' in strategy_name_list
|
||||
assert '[R, S1, R, R] -> FULLY REPLICATED_10' in strategy_name_list
|
||||
assert '[R, R, S1, R] -> [R, R, S1, R, R]_11' in strategy_name_list
|
||||
assert '[R, R, R, S1] -> [R, R, R, S1, R]_12' in strategy_name_list
|
||||
assert '[R, R, R, S0] -> [R, R, R, S0, R]_13' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R]_15' in strategy_name_list
|
||||
assert '[R, R, R, S0] -> [R, R, R, S0, R]_16' in strategy_name_list
|
||||
assert '[R, R, R, S1] -> [R, R, R, S1, R]_17' in strategy_name_list
|
||||
assert '[S01, R, R, R] -> [S01, R, R, R, R]_18' in strategy_name_list
|
||||
assert '[R, S01, R, R] -> FULLY REPLICATED_19' in strategy_name_list
|
||||
assert '[R, R, S01, R] -> [R, R, S01, R, R]_20' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R]_21' in strategy_name_list
|
||||
assert '[R, R, R, S01] -> [R, R, R, S01, R]_22' in strategy_name_list
|
||||
|
||||
if tgt_shape == (8, 4, 4, 64, 16, 4):
|
||||
assert '[S0, R, R, S1] -> [S0, R, R, R, S1, R]_0' in strategy_name_list
|
||||
assert '[R, S0, R, S1] -> [R, S0, R, R, S1, R]_1' in strategy_name_list
|
||||
assert '[R, R, S0, S1] -> [R, R, R, S0, S1, R]_2' in strategy_name_list
|
||||
assert '[S1, R, R, S0] -> [S1, R, R, R, S0, R]_3' in strategy_name_list
|
||||
assert '[R, S1, R, S0] -> [R, S1, R, R, S0, R]_4' in strategy_name_list
|
||||
assert '[R, R, S1, S0] -> [R, R, R, S1, S0, R]_5' in strategy_name_list
|
||||
assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_6' in strategy_name_list
|
||||
assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_7' in strategy_name_list
|
||||
assert '[R, R, S0, R] -> [R, R, R, S0, R, R]_8' in strategy_name_list
|
||||
assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_9' in strategy_name_list
|
||||
assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_10' in strategy_name_list
|
||||
assert '[R, R, S1, R] -> [R, R, R, S1, R, R]_11' in strategy_name_list
|
||||
assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_12' in strategy_name_list
|
||||
assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_13' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R, R]_14' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R, R]_15' in strategy_name_list
|
||||
assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_16' in strategy_name_list
|
||||
assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_17' in strategy_name_list
|
||||
assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_18' in strategy_name_list
|
||||
assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_19' in strategy_name_list
|
||||
assert '[R, R, S01, R] -> [R, R, R, S01, R, R]_20' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R, R]_21' in strategy_name_list
|
||||
assert '[R, R, R, S01] -> [R, R, R, R, S01, R]_22' in strategy_name_list
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@parameterize('tgt_shape', [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)])
|
||||
@parameterize('model_cls', [ConvViewModel, LinearViewModel])
|
||||
def test_view_handler(tgt_shape, model_cls):
|
||||
world_size = 4
|
||||
run_func = partial(check_view_handler,
|
||||
tgt_shape=tgt_shape,
|
||||
model_cls=model_cls,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -87,6 +87,11 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
|
|||
solution_len = len(strategies_constructor.leaf_strategies)
|
||||
solution = [0] * solution_len
|
||||
solution[node_index] = strategy_index
|
||||
elif node_type == 'following':
|
||||
solution_len = len(strategies_constructor.leaf_strategies)
|
||||
solution = [0] * solution_len
|
||||
solution[node_index] = strategy_index
|
||||
solution[node_index + 1] = strategy_index
|
||||
else:
|
||||
node_vector = strategies_constructor.leaf_strategies[node_index]
|
||||
strategy_to_keep = node_vector[strategy_index]
|
||||
|
@ -121,7 +126,6 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
|
|||
grad_to_shard = grad_to_shard_dict[key]
|
||||
grad_to_compare = grad_to_compare_dict[key]
|
||||
assert_close_helper(grad_to_shard, grad_to_compare, strategy_index=strategy_index, type='input grad')
|
||||
|
||||
# extract the strategy used in this iter
|
||||
strategy_in_use = target_node.strategies_vector[strategy_index]
|
||||
param_to_shard_dict = dict(gm.named_parameters())
|
||||
|
|
Loading…
Reference in New Issue