diff --git a/colossalai/engine/schedule/_non_pipeline_schedule.py b/colossalai/engine/schedule/_non_pipeline_schedule.py
index e6e31a195..8e41df53b 100644
--- a/colossalai/engine/schedule/_non_pipeline_schedule.py
+++ b/colossalai/engine/schedule/_non_pipeline_schedule.py
@@ -67,8 +67,8 @@ class NonPipelineSchedule(BaseSchedule):
             "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
         batch_data = self.load_batch(data_iter)

-        if self.batch_data_process_func:
-            data, label = self.batch_data_process_func(batch_data)
+        if self.data_process_func:
+            data, label = self.data_process_func(batch_data)
         else:
             # if not batch data process func is given,
             # then we regard the batch data as a simple tuple of (data, label)
diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py
index 1063c6c97..6f1d755b8 100644
--- a/colossalai/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/engine/schedule/_pipeline_schedule.py
@@ -141,6 +141,8 @@ class PipelineSchedule(BaseSchedule):
             for element in data:
                 if isinstance(element, dict):
                     data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()})
+                elif data_dict:
+                    data_dict['label'] = element[offset:offset + self.microbatch_size]
             if data_dict:
                 return data_dict
             return [val[offset:offset + self.microbatch_size] for val in data]
@@ -175,7 +177,10 @@ class PipelineSchedule(BaseSchedule):
         elif isinstance(data, (list, tuple)):
             return model(*data)
         elif isinstance(data, dict):
-            return model(**data)
+            stage_output = None
+            if 'stage_output' in data:
+                stage_output = data.pop('stage_output')
+            return model(stage_output, **data)
         else:
             raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")

@@ -204,41 +209,14 @@ class PipelineSchedule(BaseSchedule):
                 data = stage_output
                 _, label = micro_batch_data
         elif isinstance(micro_batch_data, dict):
-            args = []
             data = {}
-            label = {}
-
-            # we feed the stage output to args first
-            # then map each arg in args to its param name
-            if stage_output is not None:
-                if isinstance(stage_output, torch.Tensor):
-                    args.append(stage_output)
-                elif isinstance(stage_output, (list, tuple)):
-                    args.extend(stage_output)
-                else:
-                    raise TypeError(
-                        f"Expected the values passed from previous pipeline stage to be torch.Tensor, list or tuple, but got {type(input_obj)}"
-                    )
-
-            # get all parameter names for the forward function of the model
-            fwd_sig = self._get_actual_forward_func(model)
-            fwd_sig_param_name = [p.name for p in fwd_sig.parameters.values()]
-
-            # build the kwargs for the forward function
-            for idx, param_name in enumerate(fwd_sig_param_name):
-                if idx < len(args):
-                    data[param_name] = args[idx]
-                else:
-                    if param_name in micro_batch_data:
-                        data[param_name] = micro_batch_data[param_name]
-
-            # get the tensors for loss
-            loss_sig = inspect.signature(criterion)
-            loss_sig_param_name = [p.name for p in loss_sig.parameters.values()]
-
-            for param_name in loss_sig_param_name:
-                if param_name in micro_batch_data:
-                    label[param_name] = micro_batch_data[param_name]
+            data['stage_output'] = stage_output
+            if 'label' in micro_batch_data:
+                label = micro_batch_data.pop('label')
+            else:
+                label = None
+            load_data = micro_batch_data
+            data.update(load_data)
         return data, label

     def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
diff --git a/colossalai/pipeline/pipelinable.py b/colossalai/pipeline/pipelinable.py
index 826da2055..d7db77c9d 100644
--- a/colossalai/pipeline/pipelinable.py
+++ b/colossalai/pipeline/pipelinable.py
@@ -66,8 +66,11 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
         modified_args = []
         for arg in args:
             if isinstance(arg, torch.nn.Module):
-                # (lyl)TODO: if nn.Module is an argument of the root module, then we should just record the module instance itself.
-                arg = self._layer_spec_dict[id(arg)]
+                # if nn.Module is an argument of a non-root module, then we should convert it to a layer spec, which makes sure the correct init method is used in the real build.
+                # if nn.Module is an argument of the root module, then we should just record the module instance itself, because that instance has been built outside of the context.
+                if id(arg) in self._layer_spec_dict:
+                    arg = self._layer_spec_dict[id(arg)]
+
             modified_args.append(arg)

         # to the same for the keyword arguments
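
The first _pipeline_schedule.py hunk changes how a micro-batch is sliced out of a dict-style batch: dict elements keep their original keys, and a trailing non-dict element (typically the label tensor) is stored under the reserved key 'label'. The snippet below is a minimal standalone sketch of that convention, not the library API; the function name and the free-standing form are assumptions for illustration only.

    # sketch of the new micro-batch slicing rule (hypothetical helper, mirrors the diff's logic)
    import torch

    def slice_micro_batch(batch, offset, microbatch_size):
        data_dict = {}
        for element in batch:
            if isinstance(element, dict):
                # keyword-style model inputs: slice every tensor along the batch dimension
                data_dict.update({k: v[offset:offset + microbatch_size] for k, v in element.items()})
            elif data_dict:
                # a positional element following a dict is treated as the label
                data_dict['label'] = element[offset:offset + microbatch_size]
        if data_dict:
            return data_dict
        # plain tuple/list batch: slice every element
        return [val[offset:offset + microbatch_size] for val in batch]

    # example: a (kwargs, label) batch of size 4 split into micro-batches of 2
    batch = ({'input_ids': torch.arange(8).reshape(4, 2)}, torch.tensor([0, 1, 0, 1]))
    print(slice_micro_batch(batch, 0, 2))
    # {'input_ids': tensor([[0, 1], [2, 3]]), 'label': tensor([0, 1])}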
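
The other two _pipeline_schedule.py hunks replace the signature-inspection logic with a simpler calling convention: the per-micro-batch kwargs dict carries the previous stage's output under the reserved key 'stage_output', which is popped and passed as the first positional argument, so a stage module is expected to accept that output as its first forward parameter (None on the first stage). The sketch below illustrates this under stated assumptions; MiddleStage and call_engine are hypothetical stand-ins, not ColossalAI's actual modules.

    # sketch of the new stage_output convention (illustrative names, assumed shapes)
    import torch
    import torch.nn as nn

    class MiddleStage(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

        def forward(self, stage_output, attention_mask=None):
            # stage_output is the tensor produced by the previous pipeline stage
            hidden = self.linear(stage_output)
            if attention_mask is not None:
                hidden = hidden * attention_mask
            return hidden

    def call_engine(model, data):
        # mirrors the new dict branch of _call_engine
        stage_output = None
        if 'stage_output' in data:
            stage_output = data.pop('stage_output')
        return model(stage_output, **data)

    model = MiddleStage()
    micro_batch = {'stage_output': torch.randn(2, 4), 'attention_mask': torch.ones(2, 4)}
    out = call_engine(model, micro_batch)
    print(out.shape)  # torch.Size([2, 4])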