[hotfix] fx get comm size bugs (#1233)

* init a checkpoint dir

* [checkpoint] support resume for cosinewarmuplr

* [checkpoint] add unit test

* fix some bugs but still not OK

* fix bugs

* make it faster

* [checkpoint] support generalized scheduler

* polish

* [tensor] torch function returns ColoTensor

* polish

* fix bugs

* remove debug info

* polish

* polish

* [tensor] test_model passes unit tests

* polish

* [hotfix] fx get comm size bug

Co-authored-by: ZhaoYi1222 <zhaoyi9499@gmail.com>
Jiarui Fang 2022-07-08 10:54:41 +08:00 committed by GitHub
parent 42ab36b762
commit 0e199d71e8
2 changed files with 6 additions and 8 deletions


@@ -15,8 +15,8 @@ def get_comm_size(prev_partition, next_partition):
     # If a node has input nodes from the parent partition,
     # the output size of those input nodes will be counted
     # and added to comm_size
-    parent_node_names = [n.name for n in parent_partition.graph.nodes]
-    for node in child_partition.graph.nodes:
+    parent_node_names = [n.name for n in prev_partition.graph.nodes]
+    for node in next_partition.graph.nodes:
         input_nodes: Dict[Node, None] = {}
         map_arg(node.args, lambda n: input_nodes.setdefault(n))
         map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
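
The bug fixed here was effectively a NameError: the function's parameters are prev_partition and next_partition, but the body referenced parent_partition and child_partition, names that exist nowhere in scope, so any call to get_comm_size would crash before counting anything. The map_arg/setdefault idiom in the loop is worth seeing in isolation; below is a self-contained toy sketch (count_cross_partition_inputs and Tiny are made-up names, and it only counts crossing inputs rather than summing their output sizes as the real function does):

from typing import Dict
import torch
from torch.fx import symbolic_trace
from torch.fx.node import Node, map_arg

def count_cross_partition_inputs(prev_nodes, next_nodes) -> int:
    # names produced by the previous partition
    prev_names = {n.name for n in prev_nodes}
    crossing = 0
    for node in next_nodes:
        # a dict keyed by Node de-duplicates inputs while keeping order,
        # mirroring the input_nodes.setdefault(n) idiom in the diff above
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, lambda n: input_nodes.setdefault(n))
        map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
        crossing += sum(1 for n in input_nodes if n.name in prev_names)
    return crossing

class Tiny(torch.nn.Module):
    def forward(self, x):
        return (x + 1) * 2

nodes = list(symbolic_trace(Tiny()).graph.nodes)
half = len(nodes) // 2    # pretend the graph is cut into two partitions
print(count_cross_partition_inputs(nodes[:half], nodes[half:]))    # 1: the add feeds the mul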


@@ -74,8 +74,9 @@ def run_1d_hybrid_tp(model_name):
             continue
         # print(name)
         # num_class = type_vocab_size = 2 | (8, 2)
-        if 'classifier' in name and 'weight' in name:
-            init_1d_row_linear(p, pg)
+        # TODO(jiaruifang) has bug if open the following 2 comments
+        # if 'classifier' in name and 'weight' in name:
+        #     init_1d_row_linear(p, pg)
         # num_class = vocab_size = 30524 | (30524, 8)
         if 'word_embeddings' in name and 'weight' in name:
             init_1d_row_embedding(p, pg)
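
For context on the two lines being commented out: init_1d_row_linear gives the classifier weight a row-sharded 1D tensor-parallel layout. As a rough illustration only (plain torch.chunk standing in for the library's sharding spec, with hypothetical shapes), row-wise sharding splits an (out_features, in_features) Linear weight along its last dimension so each rank keeps one slice:

import torch

world_size = 2
weight = torch.randn(2, 8)    # hypothetical classifier weight: (num_class, hidden)
shards = torch.chunk(weight, world_size, dim=-1)    # split along the input dimension
assert all(s.shape == (2, 8 // world_size) for s in shards)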
@@ -152,7 +153,6 @@ def run_1d_hybrid_tp(model_name):
 # Test the overrided parameters() and named_parameters() member functions
 @pytest.mark.skip
 def test_model_parameters():
     colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
@@ -186,9 +186,8 @@ def test_model_parameters():
     assert param_cnt == 2
 
-@pytest.mark.skip
+# @pytest.mark.skip
 def test_colo_optimizer():
     colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
     get_components_func = non_distributed_component_funcs.get_callable('simple_net')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
     set_seed(1)
@@ -323,7 +322,6 @@ def run_model_dist(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
-@pytest.mark.skip("under development")
 @rerun_if_address_is_in_use()
 def test_model(world_size):
     run_func = partial(run_model_dist, world_size=world_size, port=free_port())
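
With the "under development" skip removed, test_model once again launches run_model_dist across each parametrized world size. The partial + spawn idiom it relies on looks roughly like the sketch below (run_body and this local free_port are illustrative stand-ins; the real test body initializes colossalai instead of printing):

from functools import partial
import socket
import torch.multiprocessing as mp

def free_port() -> int:
    # bind to port 0 to let the OS pick an unused ephemeral port
    with socket.socket() as s:
        s.bind(('localhost', 0))
        return s.getsockname()[1]

def run_body(rank, world_size, port):
    # mp.spawn supplies rank; partial pre-binds the remaining arguments
    print(f'rank {rank}/{world_size} would join a process group on port {port}')

if __name__ == '__main__':
    world_size = 4
    run_func = partial(run_body, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)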