From 4df01949760e35b286e6a4493c8ba15fa4467146 Mon Sep 17 00:00:00 2001 From: Ziyue Jiang Date: Tue, 1 Nov 2022 14:18:50 +0800 Subject: [PATCH] [Pipeline]Adapt to Pipelinable OPT (#1782) --- colossalai/pipeline/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/colossalai/pipeline/utils.py b/colossalai/pipeline/utils.py index 5afed0225..df7226644 100644 --- a/colossalai/pipeline/utils.py +++ b/colossalai/pipeline/utils.py @@ -6,6 +6,7 @@ from colossalai.logging import get_dist_logger from colossalai.nn.layer.utils import CheckpointModule from typing import List +from collections import OrderedDict def _binary_partition(weights: List, start: int, end: int): """Returns the binary partition position of `weights`, given the start @@ -159,8 +160,10 @@ def build_kwargs_for_module(function, input_tensor, kw_dict): kwargs_offset = 0 elif isinstance(input_tensor, torch.Tensor): kwargs_offset = 1 - else: - assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.' + elif isinstance(input_tensor, (tuple, OrderedDict)): + #assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.' + # Huggingface will take their own structures based on OrderedDict as the output + # between layers so we've to close this check. kwargs_offset = len(input_tensor) args_name_list = list(sig.parameters.keys()) kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]}