mirror of https://github.com/hpcaitech/ColossalAI
fix typo tests/ (#3936)
parent
bd2c7c3297
commit
e61ffc77c6
|
@ -2,7 +2,7 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
__all__ = ['ModelZooRegistry', 'ModelAttributem', 'model_zoo']
|
__all__ = ['ModelZooRegistry', 'ModelAttribute', 'model_zoo']
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -37,7 +37,7 @@ class ModelZooRegistry(dict):
|
||||||
>>> model_zoo = ModelZooRegistry()
|
>>> model_zoo = ModelZooRegistry()
|
||||||
>>> model_zoo.register('resnet18', resnet18, resnet18_data_gen)
|
>>> model_zoo.register('resnet18', resnet18, resnet18_data_gen)
|
||||||
>>> # Run the model
|
>>> # Run the model
|
||||||
>>> data = resnresnet18_data_gen() # do not input any argument
|
>>> data = resnet18_data_gen() # do not input any argument
|
||||||
>>> model = resnet18() # do not input any argument
|
>>> model = resnet18() # do not input any argument
|
||||||
>>> out = model(**data)
|
>>> out = model(**data)
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ def check_bn_module_handler(rank, world_size, port):
|
||||||
# the index of bn node in computation graph
|
# the index of bn node in computation graph
|
||||||
node_index = 1
|
node_index = 1
|
||||||
# the total number of bn strategies without sync bn mode
|
# the total number of bn strategies without sync bn mode
|
||||||
# TODO: add sync bn stategies after related passes ready
|
# TODO: add sync bn strategies after related passes ready
|
||||||
strategy_number = 4
|
strategy_number = 4
|
||||||
numerical_test_for_node_strategy(model=model,
|
numerical_test_for_node_strategy(model=model,
|
||||||
device_mesh=device_mesh,
|
device_mesh=device_mesh,
|
||||||
|
|
|
@ -43,14 +43,14 @@ def test_output_handler(output_option):
|
||||||
output_strategies_vector = StrategiesVector(output_node)
|
output_strategies_vector = StrategiesVector(output_node)
|
||||||
|
|
||||||
# build handler
|
# build handler
|
||||||
otuput_handler = OutputHandler(node=output_node,
|
output_handler = OutputHandler(node=output_node,
|
||||||
device_mesh=device_mesh,
|
device_mesh=device_mesh,
|
||||||
strategies_vector=output_strategies_vector,
|
strategies_vector=output_strategies_vector,
|
||||||
output_option=output_option)
|
output_option=output_option)
|
||||||
|
|
||||||
otuput_handler.register_strategy(compute_resharding_cost=False)
|
output_handler.register_strategy(compute_resharding_cost=False)
|
||||||
# check operation data mapping
|
# check operation data mapping
|
||||||
mapping = otuput_handler.get_operation_data_mapping()
|
mapping = output_handler.get_operation_data_mapping()
|
||||||
|
|
||||||
for name, op_data in mapping.items():
|
for name, op_data in mapping.items():
|
||||||
op_data: OperationData
|
op_data: OperationData
|
||||||
|
@ -59,7 +59,7 @@ def test_output_handler(output_option):
|
||||||
|
|
||||||
assert mapping['output'].name == "output"
|
assert mapping['output'].name == "output"
|
||||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||||
strategy_name_list = [val.name for val in otuput_handler.strategies_vector]
|
strategy_name_list = [val.name for val in output_handler.strategies_vector]
|
||||||
if output_option == 'distributed':
|
if output_option == 'distributed':
|
||||||
assert "Distributed Output" in strategy_name_list
|
assert "Distributed Output" in strategy_name_list
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -137,7 +137,7 @@ def check_layout_converting(rank, world_size, port):
|
||||||
assert comm_action_sequence[2].shard_dim == 0
|
assert comm_action_sequence[2].shard_dim == 0
|
||||||
assert comm_action_sequence[2].logical_process_axis == 1
|
assert comm_action_sequence[2].logical_process_axis == 1
|
||||||
|
|
||||||
# checkout chached_spec_pairs_transform_path
|
# checkout cached_spec_pairs_transform_path
|
||||||
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path
|
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path
|
||||||
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence
|
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue