mirror of https://github.com/hpcaitech/ColossalAI
fix typo tests/ (#3936)
parent
bd2c7c3297
commit
e61ffc77c6
|
@ -2,7 +2,7 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
__all__ = ['ModelZooRegistry', 'ModelAttributem', 'model_zoo']
|
||||
__all__ = ['ModelZooRegistry', 'ModelAttribute', 'model_zoo']
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -37,7 +37,7 @@ class ModelZooRegistry(dict):
|
|||
>>> model_zoo = ModelZooRegistry()
|
||||
>>> model_zoo.register('resnet18', resnet18, resnet18_data_gen)
|
||||
>>> # Run the model
|
||||
>>> data = resnresnet18_data_gen() # do not input any argument
|
||||
>>> data = resnet18_data_gen() # do not input any argument
|
||||
>>> model = resnet18() # do not input any argument
|
||||
>>> out = model(**data)
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ def check_bn_module_handler(rank, world_size, port):
|
|||
# the index of bn node in computation graph
|
||||
node_index = 1
|
||||
# the total number of bn strategies without sync bn mode
|
||||
# TODO: add sync bn stategies after related passes ready
|
||||
# TODO: add sync bn strategies after related passes ready
|
||||
strategy_number = 4
|
||||
numerical_test_for_node_strategy(model=model,
|
||||
device_mesh=device_mesh,
|
||||
|
|
|
@ -43,14 +43,14 @@ def test_output_handler(output_option):
|
|||
output_strategies_vector = StrategiesVector(output_node)
|
||||
|
||||
# build handler
|
||||
otuput_handler = OutputHandler(node=output_node,
|
||||
output_handler = OutputHandler(node=output_node,
|
||||
device_mesh=device_mesh,
|
||||
strategies_vector=output_strategies_vector,
|
||||
output_option=output_option)
|
||||
|
||||
otuput_handler.register_strategy(compute_resharding_cost=False)
|
||||
output_handler.register_strategy(compute_resharding_cost=False)
|
||||
# check operation data mapping
|
||||
mapping = otuput_handler.get_operation_data_mapping()
|
||||
mapping = output_handler.get_operation_data_mapping()
|
||||
|
||||
for name, op_data in mapping.items():
|
||||
op_data: OperationData
|
||||
|
@ -59,7 +59,7 @@ def test_output_handler(output_option):
|
|||
|
||||
assert mapping['output'].name == "output"
|
||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||
strategy_name_list = [val.name for val in otuput_handler.strategies_vector]
|
||||
strategy_name_list = [val.name for val in output_handler.strategies_vector]
|
||||
if output_option == 'distributed':
|
||||
assert "Distributed Output" in strategy_name_list
|
||||
else:
|
||||
|
|
|
@ -137,7 +137,7 @@ def check_layout_converting(rank, world_size, port):
|
|||
assert comm_action_sequence[2].shard_dim == 0
|
||||
assert comm_action_sequence[2].logical_process_axis == 1
|
||||
|
||||
# checkout chached_spec_pairs_transform_path
|
||||
# checkout cached_spec_pairs_transform_path
|
||||
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path
|
||||
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence
|
||||
|
||||
|
|
Loading…
Reference in New Issue