From 0387a47e63520bf112f80d094b64e1ae5890d525 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Tue, 29 Aug 2023 11:25:05 +0800 Subject: [PATCH] [shardformer] fix emerged bugs after updating transformers (#4526) --- colossalai/pipeline/schedule/_utils.py | 5 ++++- tests/test_shardformer/test_model/_utils.py | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/colossalai/pipeline/schedule/_utils.py b/colossalai/pipeline/schedule/_utils.py index 3ed923927..5cd934b76 100644 --- a/colossalai/pipeline/schedule/_utils.py +++ b/colossalai/pipeline/schedule/_utils.py @@ -123,7 +123,10 @@ def merge_batch(data: List[Any]) -> Any: merged_data = [] for elem_batch in zip(*flattened_data): if isinstance(elem_batch[0], torch.Tensor): - merged_data.append(torch.cat(elem_batch, dim=0)) + if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs + merged_data.append(None) + else: + merged_data.append(torch.cat(elem_batch, dim=0)) else: merged_data.append(list(elem_batch)) return tree_unflatten(merged_data, tree_spec) diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index 811471bec..803afc48a 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -195,7 +195,11 @@ def check_output_hidden_state(org_output: Tensor, sharded_hidden_state = sharded_output.last_hidden_state if stage_manager and stage_manager.is_last_stage(): - sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=dim) + pipeline_output = sharded_output['outputs'] + if isinstance(pipeline_output, List): + sharded_hidden_state = torch.cat([output.last_hidden_state for output in pipeline_output], dim=dim) + else: + sharded_hidden_state = pipeline_output.last_hidden_state assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \ f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"