[autoparallel] add split handler (#2032)

* [autoparallel] add split handler

* add numerical test and runtime passes
YuliangLiu0306 2022-11-29 11:03:51 +08:00 committed by GitHub
parent 28aa9a4294
commit 0dbcd4a6f5
9 changed files with 500 additions and 22 deletions
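
For context, torch.split returns a tuple of tensors whose shapes match the input on every dimension except the split dimension, which is why the new handler keeps the input's sharding for each output chunk but has to drop (and gather back) any sharding along the split dimension. A minimal sketch of the op's semantics:

import torch

x = torch.rand(8, 16, 64, 64)
chunks = x.split(2, dim=0)    # tuple of 4 tensors
assert len(chunks) == 4
assert all(chunk.shape == torch.Size([2, 16, 64, 64]) for chunk in chunks)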

View File

@@ -13,6 +13,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager = ShapeConsistencyManager()
@@ -27,6 +28,23 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int,
user_node_index: int):
"""
This method will be invoked during runtime to do the shape consistency, which makes sure the activations in type of tuple or list
is converted into the user node expected form.
"""
rst = []
for index, (origin_sharding_spec,
target_sharding_spec) in enumerate(zip(origin_dict[node_index],
input_dict[node_index][user_node_index])):
rst.append(
shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec,
target_sharding_spec))
rst = type(node)(rst)
return rst
def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
"""
This method will be invoked at runtime to apply the comm action specified by the comm spec.
@@ -81,13 +99,34 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
continue
for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
continue
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
if isinstance(node.sharding_spec, (list, tuple)):
assert isinstance(
node.target_sharding_specs,
(list,
tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
total_difference = 0
for sharding_spec, target_sharding_spec in zip(node.sharding_spec,
node.target_sharding_specs[user_node_index]):
total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec)
if total_difference == 0:
continue
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
runtime_apply_for_iterable_object,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
else:
assert isinstance(node.sharding_spec,
ShardingSpec), 'node.sharding_spec should be of type ShardingSpec, tuple or list.'
if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
continue
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
new_args = list(user_node.args)
new_kwargs = dict(user_node.kwargs)
# the origin node may be a positional argument or a keyword argument of the user node
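
The element-wise conversion performed by runtime_apply_for_iterable_object can be sketched independently of the FX machinery; convert_element below is a hypothetical stand-in for ShapeConsistencyManager.apply_for_autoparallel_runtime:

def convert_element(tensor, origin_spec, target_spec):
    # stand-in: a real implementation would all-gather / shard / all-to-all the tensor
    # so that it follows target_spec instead of origin_spec
    return tensor

def apply_for_iterable(activations, origin_specs, target_specs):
    # convert every element of a tuple/list activation and keep the container type
    converted = [
        convert_element(act, o_spec, t_spec)
        for act, o_spec, t_spec in zip(activations, origin_specs, target_specs)
    ]
    return type(activations)(converted)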

View File

@@ -100,8 +100,24 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
# skip the placeholder node added in _solution_annotation pass
if not hasattr(node, 'sharding_spec'):
continue
output_dim_partition_dict = node.sharding_spec.dim_partition_dict
device_mesh = node.sharding_spec.device_mesh
def _process_sharding_spec(sharding_spec):
if isinstance(sharding_spec, ShardingSpec):
dim_partition_dict = sharding_spec.dim_partition_dict
device_mesh = sharding_spec.device_mesh
return dim_partition_dict, device_mesh
if sharding_spec is None:
return None, None
assert isinstance(sharding_spec,
(tuple, list)), 'sharding_spec should be of type ShardingSpec, tuple, list or None'
device_mesh = sharding_spec[0].device_mesh
dim_partition_dict = []
for element in sharding_spec:
dim_partition_dict.append(_process_sharding_spec(element))
return dim_partition_dict, sharding_spec
output_dim_partition_dict, device_mesh = _process_sharding_spec(node.sharding_spec)
new_args = []
if node.op == 'call_method':
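
As a reminder of the data structure handled here, a ShardingSpec's dim_partition_dict maps a tensor dimension to the list of device-mesh axes it is sharded over; for a tuple/list spec the pass collects one such dict per element. A rough sketch using plain dicts (collect_dim_partitions is illustrative, not the real API):

# single tensor: dim 0 sharded over mesh axis 0, dim 1 over mesh axis 1
single_spec = {0: [0], 1: [1]}
# a split node produces a tuple, so its spec becomes a list of per-element dicts
tuple_spec = [{0: [0]}, {0: [0]}]

def collect_dim_partitions(spec):
    if isinstance(spec, dict):
        return spec
    if spec is None:
        return None
    return [collect_dim_partitions(element) for element in spec]

print(collect_dim_partitions(tuple_spec))    # [{0: [0]}, {0: [0]}]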

View File

@@ -1,8 +1,10 @@
from .permute_handler import PermuteHandler
from .reshape_generator import PermuteGenerator, TransposeGenerator, ViewGenerator
from .reshape_generator import PermuteGenerator, SplitGenerator, TransposeGenerator, ViewGenerator
from .split_handler import SplitHandler
from .transpose_handler import TransposeHandler
from .view_handler import ViewHandler
__all__ = [
'ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeGenerator'
'ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeHandler',
'SplitHandler', 'SplitGenerator'
]

View File

@@ -17,7 +17,7 @@ from colossalai.auto_parallel.tensor_shard.utils import (
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator']
__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator']
class ReshapeGenerator(FollowingStrategyGenerator):
@@ -227,3 +227,73 @@ class TransposeGenerator(ReshapeGenerator):
strategy_list.append(strategy)
return strategy_list
class SplitGenerator(ReshapeGenerator):
"""
SplitGenerator deals with the sharding strategies of the split op.
"""
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
recover_dims = None
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
split_size, split_dim = self.op_data['split_info'].data
if split_dim in dim_partition_dict_for_input:
recover_dims = dim_partition_dict_for_input.pop(split_dim)
dim_partition_dict_for_output = [
copy.deepcopy(dim_partition_dict_for_input) for _ in range(len(self.op_data["output"].data))
]
assert len(dim_partition_dict_for_output) >= 2
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add the index to the name to pass the duplicate-name check
# we keep identical strategies under different names for node merging; this does not increase the search space,
# because the solver will merge this node into other nodes and will not create a new variable for it.
name = f'{sharding_spec_mapping["input"].sharding_sequence}_{index}'
# add a comm action if the input needs to be recovered to replica along the split dimension.
if recover_dims:
# if there is only one sharding dimension, use the value itself instead of a list as logical_process_axis.
if len(recover_dims) == 1:
recover_dims = recover_dims[0]
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=recover_dims,
comm_type=CommType.BEFORE,
arg_index=0)
# it will gather the input along gather_dim during the forward phase.
input_comm_action.comm_spec.gather_dim = split_dim
# it will split the gradient of the input activation along split_dim during the backward phase.
input_comm_action.comm_spec.shard_dim = split_dim
elif len(recover_dims) >= 2:
# original sharding spec
source_spec = input_sharding_spec
# target sharding spec
target_spec = sharding_spec_mapping["input"]
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
input_comm_action = None
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
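
The core transformation in SplitGenerator.collate_strategies can be summarised on plain dicts: remove any sharding on the split dimension from the input's dim_partition_dict, remember the freed mesh axes so a gather (forward) / split (backward) comm action can be attached, and replicate the result once per output chunk. A stand-alone sketch (split_output_partitions is a hypothetical helper):

import copy
from typing import Dict, List, Optional, Tuple

def split_output_partitions(input_partition: Dict[int, List[int]], split_dim: int,
                            num_outputs: int) -> Tuple[Dict, List[Dict], Optional[List[int]]]:
    new_input_partition = copy.deepcopy(input_partition)
    # sharding along split_dim cannot be kept by the outputs
    recover_dims = new_input_partition.pop(split_dim, None)
    output_partitions = [copy.deepcopy(new_input_partition) for _ in range(num_outputs)]
    return new_input_partition, output_partitions, recover_dims

# input sharded as S0 on dim 1, then split along dim 1 into 2 chunks:
print(split_output_partitions({1: [0]}, split_dim=1, num_outputs=2))
# -> ({}, [{}, {}], [0])  i.e. the input must be gathered over mesh axis 0 first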

View File

@@ -0,0 +1,63 @@
from typing import Dict, List
import torch
from ...sharding_strategy import OperationData, OperationDataType
from ..node_handler import NodeHandler
from ..registry import operator_registry
from ..strategy import StrategyGenerator
from .reshape_generator import SplitGenerator
__all__ = ['SplitHandler']
@operator_registry.register(torch.Tensor.split)
@operator_registry.register(torch.split)
class SplitHandler(NodeHandler):
"""
A SplitHandler which deals with the sharding strategies for torch.Tensor.split and torch.split.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(SplitGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
split_size = self.node.args[1]
if len(self.node.args) == 3:
# (input, split_size, split_dim)
split_dim = self.node.args[2]
else:
if self.node.kwargs:
split_dim = self.node.kwargs['dim']
else:
split_dim = 0
num_dims = self.node.args[0]._meta_data.dim()
# convert a negative split dim into its positive equivalent
if split_dim < 0:
split_dim += num_dims
split_info = (split_size, split_dim)
physical_shape_operand = OperationData(name='split_info', type=OperationDataType.ARG, data=split_info)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {
"input": physical_input_operand,
"split_info": physical_shape_operand,
"output": physical_output_operand
}
return mapping
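
The argument handling above normalizes both calling conventions of torch.split / Tensor.split into a (split_size, split_dim) pair. The same normalization on raw args/kwargs, as a small sketch (normalize_split_args is illustrative only):

import torch

def normalize_split_args(tensor: torch.Tensor, *args, **kwargs):
    split_size = args[0]
    if len(args) >= 2:
        split_dim = args[1]
    else:
        split_dim = kwargs.get('dim', 0)
    # convert a negative dim into its positive equivalent
    if split_dim < 0:
        split_dim += tensor.dim()
    return split_size, split_dim

x = torch.rand(8, 16, 64, 64)
print(normalize_split_args(x, 2, -1))       # (2, 3)
print(normalize_split_args(x, 2, dim=1))    # (2, 1)
print(normalize_split_args(x, 2))           # (2, 0)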

View File

@@ -10,8 +10,6 @@ from .strategy import ReshapeGenerator, StrategyGenerator
__all__ = ['ReshapeHandler']
@operator_registry.register(torch.Tensor.split)
@operator_registry.register(torch.split)
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
class ReshapeHandler(NodeHandler):

View File

@@ -49,12 +49,23 @@ class OutputGenerator(OutputStrategyGenerator):
"""
Generate replica strategy for output node.
"""
dim_partition_dict_mapping = {
"output": {},
}
dim_partition_dict_mapping = {}
dim_partition_dict_for_output = []
for index, _ in enumerate(self.predecessor_nodes):
mapping_name = f"input_{index}"
dim_partition_dict_mapping[mapping_name] = {}
if isinstance(self.op_data[mapping_name].data, (tuple, list)):
dim_partition_dict_for_input = [{} for _ in range(len(self.op_data[mapping_name].data))]
else:
dim_partition_dict_for_input = {}
dim_partition_dict_mapping[mapping_name] = dim_partition_dict_for_input
dim_partition_dict_for_output.append(dim_partition_dict_for_input)
if len(dim_partition_dict_for_output) == 1:
dim_partition_dict_for_output = dim_partition_dict_for_output[0]
else:
dim_partition_dict_for_output = tuple(dim_partition_dict_for_output)
dim_partition_dict_mapping['output'] = dim_partition_dict_for_output
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
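
What the updated OutputGenerator builds for the replica strategy, sketched with plain data structures: every predecessor gets an all-replica dim_partition_dict, and a predecessor that yields a tuple/list (such as split) gets one empty dict per element (build_replica_mapping is a hypothetical helper):

def build_replica_mapping(predecessor_outputs):
    mapping = {}
    output_partitions = []
    for index, data in enumerate(predecessor_outputs):
        partition = [{} for _ in range(len(data))] if isinstance(data, (tuple, list)) else {}
        mapping[f"input_{index}"] = partition
        output_partitions.append(partition)
    mapping["output"] = output_partitions[0] if len(output_partitions) == 1 else tuple(output_partitions)
    return mapping

print(build_replica_mapping(["tensor"]))                # {'input_0': {}, 'output': {}}
print(build_replica_mapping([("t0", "t1"), "tensor"]))  # per-element dicts for the tuple input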

View File

@@ -0,0 +1,270 @@
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.experimental import SplitHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
class ConvSplitModel(nn.Module):
def __init__(self, split_size, split_dim):
super().__init__()
self.split_size = split_size
self.split_dim = split_dim
def forward(self, input, other):
conv_node = nn.functional.conv2d(input, other, bias=None)
split_node = conv_node.split(self.split_size, dim=self.split_dim)
return split_node
class LinearSplitModel(nn.Module):
def __init__(self, split_size, split_dim):
super().__init__()
self.split_size = split_size
self.split_dim = split_dim
def forward(self, input, other):
linear_node = nn.functional.linear(input, other, bias=None)
split_node = linear_node.split(self.split_size, dim=self.split_dim)
return split_node
def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = model_cls(split_size=split_size, split_dim=split_dim).cuda()
if model_cls.__name__ == 'ConvSplitModel':
input = torch.rand(8, 8, 66, 66).to('cuda')
other = torch.rand(16, 8, 3, 3).to('cuda')
# index of conv node in computation graph
node_index = 2
# total number of conv strategies
strategy_number = 16
if model_cls.__name__ == 'LinearSplitModel':
input = torch.rand(8, 16, 64, 32).to('cuda')
other = torch.rand(64, 32).to('cuda')
# index of linear node in computation graph
node_index = 2
# total number of linear strategies
strategy_number = 23
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
numerical_test_for_node_strategy(model=model,
device_mesh=device_mesh,
node_index=node_index,
strategy_number=strategy_number,
input_args=[input, other],
meta_arg_names=['input', 'other'],
node_type='following')
tracer = ColoTracer()
if model_cls.__name__ == 'ConvSplitModel':
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %other : torch.Tensor [#users=1] = placeholder[target=other]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %split : [#users=1] = call_method[target=split](args = (%conv2d,), kwargs = {})
# return split
graph = tracer.trace(model,
meta_args={
"input": torch.rand(8, 8, 66, 66).to('meta'),
"other": torch.rand(16, 8, 3, 3).to('meta'),
})
if model_cls.__name__ == 'LinearSplitModel':
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %other : torch.Tensor [#users=1] = placeholder[target=other]
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
# %split : [#users=1] = call_method[target=split](args = (%linear,), kwargs = {})
# return split
graph = tracer.trace(model,
meta_args={
"input": torch.rand(8, 16, 64, 32).to('meta'),
"other": torch.rand(64, 32).to('meta'),
})
gm = ColoGraphModule(model, graph)
previous_mod_node = list(graph.nodes)[2]
split_node = list(graph.nodes)[3]
split_strategies_vector = StrategiesVector(split_node)
previous_strategies_vector = StrategiesVector(previous_mod_node)
# build handler
if model_cls.__name__ == 'ConvSplitModel':
conv_handler = ConvFunctionHandler(node=previous_mod_node,
device_mesh=device_mesh,
strategies_vector=previous_strategies_vector)
conv_handler.register_strategy(compute_resharding_cost=False)
setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector)
if model_cls.__name__ == 'LinearSplitModel':
assert len(previous_strategies_vector) == 0
linear_handler = LinearFunctionHandler(node=previous_mod_node,
device_mesh=device_mesh,
strategies_vector=previous_strategies_vector)
linear_handler.register_strategy(compute_resharding_cost=False)
setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector)
split_handler = SplitHandler(node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector)
split_handler.register_strategy(compute_resharding_cost=False)
# check operation data mapping
mapping = split_handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
# make sure they have valid values
assert op_data.data is not None
if model_cls.__name__ == 'ConvSplitModel':
assert mapping['input'].name == "conv2d"
else:
assert mapping['input'].name == "linear"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64])
assert mapping['output'].name == "split"
split_items = torch.empty([8, 16, 64, 64]).split(split_size, split_dim)
assert mapping['output'].logical_shape == tuple([item.shape for item in split_items])
assert mapping['output'].type == OperationDataType.OUTPUT
# split handler is a following strategy handler, so its number of strategies is equal to that of the predecessor node.
assert len(split_strategies_vector) == len(previous_strategies_vector)
strategy_name_list = [strategy.name for strategy in split_strategies_vector]
for name in strategy_name_list:
print(name)
if model_cls.__name__ == 'ConvSplitModel':
if split_dim == 0:
assert '[R, S1, R, R]_0' in strategy_name_list
assert '[R, S0, R, R]_1' in strategy_name_list
assert '[R, R, R, R]_2' in strategy_name_list
assert '[R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, R]_4' in strategy_name_list
assert '[R, R, R, R]_5' in strategy_name_list
assert '[R, S1, R, R]_6' in strategy_name_list
assert '[R, S0, R, R]_7' in strategy_name_list
assert '[R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R]_9' in strategy_name_list
assert '[R, S0, R, R]_10' in strategy_name_list
assert '[R, S1, R, R]_11' in strategy_name_list
assert '[R, R, R, R]_12' in strategy_name_list
assert '[R, R, R, R]_13' in strategy_name_list
assert '[R, R, R, R]_14' in strategy_name_list
assert '[R, S01, R, R]_15' in strategy_name_list
if split_dim == 1:
assert '[S0, R, R, R]_0' in strategy_name_list
assert '[S1, R, R, R]_1' in strategy_name_list
assert '[S0, R, R, R]_2' in strategy_name_list
assert '[S1, R, R, R]_3' in strategy_name_list
assert '[S0, R, R, R]_4' in strategy_name_list
assert '[S1, R, R, R]_5' in strategy_name_list
assert '[R, R, R, R]_6' in strategy_name_list
assert '[R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R]_9' in strategy_name_list
assert '[R, R, R, R]_10' in strategy_name_list
assert '[R, R, R, R]_11' in strategy_name_list
assert '[R, R, R, R]_12' in strategy_name_list
assert '[S01, R, R, R]_13' in strategy_name_list
assert '[R, R, R, R]_14' in strategy_name_list
assert '[R, R, R, R]_15' in strategy_name_list
if model_cls.__name__ == 'LinearSplitModel':
if split_dim == 0:
assert '[R, R, R, S1]_0' in strategy_name_list
assert '[R, S0, R, S1]_1' in strategy_name_list
assert '[R, R, S0, S1]_2' in strategy_name_list
assert '[R, R, R, S0]_3' in strategy_name_list
assert '[R, S1, R, S0]_4' in strategy_name_list
assert '[R, R, S1, S0]_5' in strategy_name_list
assert '[R, R, R, R]_6' in strategy_name_list
assert '[R, S0, R, R]_7' in strategy_name_list
assert '[R, R, S0, R]_8' in strategy_name_list
assert '[R, R, R, R]_9' in strategy_name_list
assert '[R, S1, R, R]_10' in strategy_name_list
assert '[R, R, S1, R]_11' in strategy_name_list
assert '[R, R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0]_13' in strategy_name_list
assert '[R, R, R, R]_14' in strategy_name_list
assert '[R, R, R, R]_15' in strategy_name_list
assert '[R, R, R, S0]_16' in strategy_name_list
assert '[R, R, R, S1]_17' in strategy_name_list
assert '[R, R, R, R]_18' in strategy_name_list
assert '[R, S01, R, R]_19' in strategy_name_list
assert '[R, R, S01, R]_20' in strategy_name_list
assert '[R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01]_22' in strategy_name_list
if split_dim == 1:
assert '[S0, R, R, S1]_0' in strategy_name_list
assert '[R, R, R, S1]_1' in strategy_name_list
assert '[R, R, S0, S1]_2' in strategy_name_list
assert '[S1, R, R, S0]_3' in strategy_name_list
assert '[R, R, R, S0]_4' in strategy_name_list
assert '[R, R, S1, S0]_5' in strategy_name_list
assert '[S0, R, R, R]_6' in strategy_name_list
assert '[R, R, R, R]_7' in strategy_name_list
assert '[R, R, S0, R]_8' in strategy_name_list
assert '[S1, R, R, R]_9' in strategy_name_list
assert '[R, R, R, R]_10' in strategy_name_list
assert '[R, R, S1, R]_11' in strategy_name_list
assert '[R, R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0]_13' in strategy_name_list
assert '[R, R, R, R]_14' in strategy_name_list
assert '[R, R, R, R]_15' in strategy_name_list
assert '[R, R, R, S0]_16' in strategy_name_list
assert '[R, R, R, S1]_17' in strategy_name_list
assert '[S01, R, R, R]_18' in strategy_name_list
assert '[R, R, R, R]_19' in strategy_name_list
assert '[R, R, S01, R]_20' in strategy_name_list
assert '[R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01]_22' in strategy_name_list
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
@parameterize('split_size', [2])
@parameterize('split_dim', [0, 1, 2])
@parameterize('model_cls', [ConvSplitModel, LinearSplitModel])
def test_split_handler(split_size, split_dim, model_cls):
world_size = 4
run_func = partial(check_split_handler,
split_size=split_size,
split_dim=split_dim,
model_cls=model_cls,
world_size=world_size,
port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_split_handler()
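
A quick sanity check of the logical output shape asserted in the test: the split output's logical_shape is the tuple of per-chunk shapes, e.g. for split_size=2 along dim 1 of an [8, 16, 64, 64] tensor:

import torch

split_items = torch.empty([8, 16, 64, 64]).split(2, 1)
print(len(split_items))           # 8 chunks (16 / 2) along dim 1
print(split_items[0].shape)       # torch.Size([8, 2, 64, 64])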

View File

@@ -118,10 +118,15 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
assert_close_helper(output, output_to_compare, strategy_index=strategy_index, type='forward output')
# backward result compare
loss = output.sum()
loss_to_compare = output_to_compare.sum()
loss.backward()
if isinstance(output, (tuple, list)):
loss = output[0].sum()
loss_to_compare = output_to_compare[0].sum()
else:
loss = output.sum()
loss_to_compare = output_to_compare.sum()
loss_to_compare.backward()
loss.backward()
for key in grad_to_shard_dict.keys():
grad_to_shard = grad_to_shard_dict[key]
grad_to_compare = grad_to_compare_dict[key]
@@ -157,6 +162,10 @@ def assert_close_helper(first: torch.Tensor,
"""
# average_diff_tensor = ((first - second)/(second+0.1)).sum()/second.numel()
try:
assert_close(first, second, rtol=rtol, atol=atol)
if isinstance(first, (tuple, list)):
for first_element, second_element in zip(first, second):
assert_close(first_element, second_element, rtol=rtol, atol=atol)
else:
assert_close(first, second, rtol=rtol, atol=atol)
except:
print(f'strategy index {strategy_index} encountered an assert_close error on {type}')
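
The tuple-aware comparison added above follows the usual element-wise pattern; the same idea written against torch.testing.assert_close directly:

import torch
from torch.testing import assert_close

def assert_close_maybe_iterable(first, second, rtol=1e-3, atol=1e-3):
    # compare tuple/list outputs element by element, plain tensors directly
    if isinstance(first, (tuple, list)):
        for first_element, second_element in zip(first, second):
            assert_close(first_element, second_element, rtol=rtol, atol=atol)
    else:
        assert_close(first, second, rtol=rtol, atol=atol)

out = torch.rand(4, 4).split(2, dim=0)
assert_close_maybe_iterable(out, tuple(t.clone() for t in out))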