fix typo tests/ (#3936)

pull/3946/head
digger yu 2023-06-09 09:49:41 +08:00 committed by GitHub
parent bd2c7c3297
commit e61ffc77c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 8 additions and 8 deletions

View File

@ -2,7 +2,7 @@
from dataclasses import dataclass
from typing import Callable
__all__ = ['ModelZooRegistry', 'ModelAttributem', 'model_zoo']
__all__ = ['ModelZooRegistry', 'ModelAttribute', 'model_zoo']
@dataclass
@ -37,7 +37,7 @@ class ModelZooRegistry(dict):
>>> model_zoo = ModelZooRegistry()
>>> model_zoo.register('resnet18', resnet18, resnet18_data_gen)
>>> # Run the model
>>> data = resnresnet18_data_gen() # do not input any argument
>>> data = resnet18_data_gen() # do not input any argument
>>> model = resnet18() # do not input any argument
>>> out = model(**data)

View File

@ -27,7 +27,7 @@ def check_bn_module_handler(rank, world_size, port):
# the index of bn node in computation graph
node_index = 1
# the total number of bn strategies without sync bn mode
# TODO: add sync bn stategies after related passes ready
# TODO: add sync bn strategies after related passes ready
strategy_number = 4
numerical_test_for_node_strategy(model=model,
device_mesh=device_mesh,

View File

@ -43,14 +43,14 @@ def test_output_handler(output_option):
output_strategies_vector = StrategiesVector(output_node)
# build handler
otuput_handler = OutputHandler(node=output_node,
output_handler = OutputHandler(node=output_node,
device_mesh=device_mesh,
strategies_vector=output_strategies_vector,
output_option=output_option)
otuput_handler.register_strategy(compute_resharding_cost=False)
output_handler.register_strategy(compute_resharding_cost=False)
# check operation data mapping
mapping = otuput_handler.get_operation_data_mapping()
mapping = output_handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
@ -59,7 +59,7 @@ def test_output_handler(output_option):
assert mapping['output'].name == "output"
assert mapping['output'].type == OperationDataType.OUTPUT
strategy_name_list = [val.name for val in otuput_handler.strategies_vector]
strategy_name_list = [val.name for val in output_handler.strategies_vector]
if output_option == 'distributed':
assert "Distributed Output" in strategy_name_list
else:

View File

@ -137,7 +137,7 @@ def check_layout_converting(rank, world_size, port):
assert comm_action_sequence[2].shard_dim == 0
assert comm_action_sequence[2].logical_process_axis == 1
# checkout chached_spec_pairs_transform_path
# checkout cached_spec_pairs_transform_path
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence