mirror of https://github.com/hpcaitech/ColossalAI
[Pipeline]Adapt to Pipelinable OPT (#1782)
parent
27de252334
commit
4df0194976
|
@ -6,6 +6,7 @@ from colossalai.logging import get_dist_logger
|
|||
from colossalai.nn.layer.utils import CheckpointModule
|
||||
from typing import List
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
def _binary_partition(weights: List, start: int, end: int):
|
||||
"""Returns the binary partition position of `weights`, given the start
|
||||
|
@ -159,8 +160,10 @@ def build_kwargs_for_module(function, input_tensor, kw_dict):
|
|||
kwargs_offset = 0
|
||||
elif isinstance(input_tensor, torch.Tensor):
|
||||
kwargs_offset = 1
|
||||
else:
|
||||
assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
|
||||
elif isinstance(input_tensor, (tuple, OrderedDict)):
|
||||
#assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
|
||||
# Huggingface will take their own structures based on OrderedDict as the output
|
||||
# between layers so we've to close this check.
|
||||
kwargs_offset = len(input_tensor)
|
||||
args_name_list = list(sig.parameters.keys())
|
||||
kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]}
|
||||
|
|
Loading…
Reference in New Issue