diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/engine/schedule/_base_schedule.py index 687e220bd..a144db6a0 100644 --- a/colossalai/engine/schedule/_base_schedule.py +++ b/colossalai/engine/schedule/_base_schedule.py @@ -48,9 +48,11 @@ class BaseSchedule(ABC): if isinstance(data, torch.Tensor): return data.size(0) elif isinstance(data, (list, tuple)): + if isinstance(data[0], dict): + return data[0][list(data[0].keys())[0]].size(0) return data[0].size(0) elif isinstance(data, dict): - return data[next(data.keys())].size(0) + return data[list(data.keys())[0]].size(0) def load_batch(self, data_iter, to_gpu=True): """Loads a batch from data iterator. It returns the data and labels which are diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py index bcd91ea6d..1063c6c97 100644 --- a/colossalai/engine/schedule/_pipeline_schedule.py +++ b/colossalai/engine/schedule/_pipeline_schedule.py @@ -137,6 +137,12 @@ class PipelineSchedule(BaseSchedule): if isinstance(data, torch.Tensor): return data[offset:offset + self.microbatch_size] elif isinstance(data, (list, tuple)): + data_dict = {} + for element in data: + if isinstance(element, dict): + data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()}) + if data_dict: + return data_dict return [val[offset:offset + self.microbatch_size] for val in data] elif isinstance(data, dict): return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()} @@ -216,7 +222,7 @@ class PipelineSchedule(BaseSchedule): # get all parameter names for the forward function of the model fwd_sig = self._get_actual_forward_func(model) - fwd_sig_param_name = [p.name for p in fwd_sig.values()] + fwd_sig_param_name = [p.name for p in fwd_sig.parameters.values()] # build the kwargs for the forward function for idx, param_name in enumerate(fwd_sig_param_name): @@ -228,7 +234,7 @@ class PipelineSchedule(BaseSchedule): # get the tensors for loss loss_sig = inspect.signature(criterion) - loss_sig_param_name = [p.name for p in loss_sig.values()] + loss_sig_param_name = [p.name for p in loss_sig.parameters.values()] for param_name in loss_sig_param_name: if param_name in micro_batch_data: diff --git a/colossalai/pipeline/pipelinable.py b/colossalai/pipeline/pipelinable.py index 317ac0c21..826da2055 100644 --- a/colossalai/pipeline/pipelinable.py +++ b/colossalai/pipeline/pipelinable.py @@ -12,7 +12,7 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): A context manager to split the model into pipeline stages. """ - def __init__(self, policy: str="balanced"): + def __init__(self, policy: str = "balanced"): super().__init__() self._layer_spec_dict = {} self._root_children = None @@ -61,11 +61,12 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): """ # iterate over the positional arguments # to check if an argument is a torch Module - # if found any torch Module, replace it with its layer spec + # if found any torch Module, replace it with its layer spec # for storage purpose modified_args = [] for arg in args: if isinstance(arg, torch.nn.Module): + # (lyl)TODO: if nn.Module is an argument of the root module, then we should just record the module instance itself. arg = self._layer_spec_dict[id(arg)] modified_args.append(arg) @@ -255,6 +256,3 @@ class PipelinableModel(torch.nn.Module): input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs) return input_tensor - - -