[hotfix] fix bugs caused by refactored pipeline (#1133)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [hotfix] fix bugs caused by refactored pipeline
YuliangLiu0306 2022-06-17 17:54:15 +08:00 committed by GitHub
parent 789cad301b
commit 946dbd629d
3 changed files with 20 additions and 39 deletions


@@ -67,8 +67,8 @@ class NonPipelineSchedule(BaseSchedule):
             "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
         batch_data = self.load_batch(data_iter)
-        if self.batch_data_process_func:
-            data, label = self.batch_data_process_func(batch_data)
+        if self.data_process_func:
+            data, label = self.data_process_func(batch_data)
         else:
             # if no batch data process func is given,
             # then we regard the batch data as a simple tuple of (data, label)
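Note: the hook renamed here receives the raw batch and must return a (data, label) pair. A minimal sketch of a compatible hook; the function name unpack_batch and the batch layout are illustrative, and how the hook is wired into the schedule's constructor is an assumption, not confirmed by this diff:

    # Sketch of a data_process_func compatible with the renamed attribute.
    # The batch layout ('input'/'target' keys) is hypothetical.
    def unpack_batch(batch_data):
        return batch_data['input'], batch_data['target']

    # Presumably passed when the schedule is created, e.g.:
    # schedule = NonPipelineSchedule(data_process_func=unpack_batch)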


@@ -141,6 +141,8 @@ class PipelineSchedule(BaseSchedule):
             for element in data:
                 if isinstance(element, dict):
                     data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()})
+                elif data_dict:
+                    data_dict['label'] = element[offset:offset + self.microbatch_size]
             if data_dict:
                 return data_dict
             return [val[offset:offset + self.microbatch_size] for val in data]
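To see what the added elif branch does, here is a standalone sketch of the same slicing logic with hardcoded toy data; microbatch_size, offset, and the loop mirror the snippet, while the data layout itself is illustrative:

    import torch

    microbatch_size = 2
    offset = 0
    # toy micro-batch source: a dict of tensors followed by a bare label tensor
    data = ({'input_ids': torch.arange(8).reshape(4, 2)}, torch.tensor([0, 1, 0, 1]))

    data_dict = {}
    for element in data:
        if isinstance(element, dict):
            data_dict.update({k: v[offset:offset + microbatch_size] for k, v in element.items()})
        elif data_dict:
            # a non-dict element following a dict is treated as the label
            data_dict['label'] = element[offset:offset + microbatch_size]

    print(data_dict['input_ids'].shape, data_dict['label'])  # torch.Size([2, 2]) tensor([0, 1])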
@@ -175,7 +177,10 @@ class PipelineSchedule(BaseSchedule):
         elif isinstance(data, (list, tuple)):
             return model(*data)
         elif isinstance(data, dict):
-            return model(**data)
+            stage_output = None
+            if 'stage_output' in data:
+                stage_output = data.pop('stage_output')
+            return model(stage_output, **data)
         else:
             raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
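A standalone sketch of the new dict branch: 'stage_output' is split off and passed positionally, the rest as keyword arguments. ToyModel and call_model below are stand-ins, not the library's classes:

    import torch

    class ToyModel(torch.nn.Module):

        def forward(self, stage_output, bias=None):
            # on the first stage, stage_output is None and inputs arrive as kwargs
            x = stage_output if stage_output is not None else torch.zeros(2)
            return x if bias is None else x + bias

    def call_model(model, data):
        stage_output = None
        if 'stage_output' in data:
            stage_output = data.pop('stage_output')
        return model(stage_output, **data)

    out = call_model(ToyModel(), {'stage_output': torch.ones(2), 'bias': torch.ones(2)})
    print(out)  # tensor([2., 2.])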
@@ -204,41 +209,14 @@ class PipelineSchedule(BaseSchedule):
             data = stage_output
             _, label = micro_batch_data
         elif isinstance(micro_batch_data, dict):
-            args = []
             data = {}
-            label = {}
-            # we feed the stage output to args first
-            # then map each arg in args to its param name
-            if stage_output is not None:
-                if isinstance(stage_output, torch.Tensor):
-                    args.append(stage_output)
-                elif isinstance(stage_output, (list, tuple)):
-                    args.extend(stage_output)
-                else:
-                    raise TypeError(
-                        f"Expected the values passed from previous pipeline stage to be torch.Tensor, list or tuple, but got {type(input_obj)}"
-                    )
-            # get all parameter names for the forward function of the model
-            fwd_sig = self._get_actual_forward_func(model)
-            fwd_sig_param_name = [p.name for p in fwd_sig.parameters.values()]
-            # build the kwargs for the forward function
-            for idx, param_name in enumerate(fwd_sig_param_name):
-                if idx < len(args):
-                    data[param_name] = args[idx]
-                else:
-                    if param_name in micro_batch_data:
-                        data[param_name] = micro_batch_data[param_name]
-            # get the tensors for loss
-            loss_sig = inspect.signature(criterion)
-            loss_sig_param_name = [p.name for p in loss_sig.parameters.values()]
-            for param_name in loss_sig_param_name:
-                if param_name in micro_batch_data:
-                    label[param_name] = micro_batch_data[param_name]
+            data['stage_output'] = stage_output
+            if 'label' in micro_batch_data:
+                label = micro_batch_data.pop('label')
+            else:
+                label = None
+            load_data = micro_batch_data
+            data.update(load_data)
         return data, label

     def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
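The simplified dict branch can be read as a small helper: store the previous stage's output under 'stage_output', split off an optional 'label', and forward everything else untouched. A standalone sketch; build_data_label is a hypothetical name, and the batch contents are illustrative:

    import torch

    def build_data_label(stage_output, micro_batch_data):
        data = {}
        data['stage_output'] = stage_output
        if 'label' in micro_batch_data:
            # note: pops 'label' out of the input dict in place, as in the diff
            label = micro_batch_data.pop('label')
        else:
            label = None
        data.update(micro_batch_data)
        return data, label

    batch = {'input_ids': torch.arange(4), 'label': torch.tensor([1, 0])}
    data, label = build_data_label(torch.ones(2), batch)
    print(sorted(data.keys()), label)  # ['input_ids', 'stage_output'] tensor([1, 0])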


@@ -66,8 +66,11 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
         modified_args = []
         for arg in args:
             if isinstance(arg, torch.nn.Module):
-                # (lyl)TODO: if nn.Module is an argument of the root module, then we should just record the module instance itself.
-                arg = self._layer_spec_dict[id(arg)]
+                # if nn.Module is an argument of a non-root module, convert it to a layer spec, which makes sure the correct init method is used in the real build.
+                # if nn.Module is an argument of the root module, just record the module instance itself, because those instances have been built outside of the context.
+                if id(arg) in self._layer_spec_dict:
+                    arg = self._layer_spec_dict[id(arg)]
             modified_args.append(arg)
         # do the same for the keyword arguments
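A standalone sketch of the guarded substitution above: only sub-modules already registered in the spec dict are swapped for their layer spec, while modules built outside the context (root-module arguments) pass through unchanged. LayerSpecStub is a stand-in for the library's layer spec wrapper:

    import torch

    class LayerSpecStub:
        # stand-in for the library's layer spec wrapper
        def __init__(self, module):
            self.module = module

    registered = torch.nn.Linear(2, 2)    # built inside the context and registered
    external = torch.nn.Linear(2, 2)      # built outside the context (root-module argument)
    _layer_spec_dict = {id(registered): LayerSpecStub(registered)}

    modified_args = []
    for arg in (registered, external, 42):
        if isinstance(arg, torch.nn.Module):
            # only registered sub-modules are swapped for their spec
            if id(arg) in _layer_spec_dict:
                arg = _layer_spec_dict[id(arg)]
        modified_args.append(arg)

    print([type(a).__name__ for a in modified_args])  # ['LayerSpecStub', 'Linear', 'int']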