From cfb41297ff419228f93eb2fdc1b19a1186859a66 Mon Sep 17 00:00:00 2001 From: yuxuan-lou <83441848+yuxuan-lou@users.noreply.github.com> Date: Thu, 31 Mar 2022 15:46:11 +0800 Subject: [PATCH] 'fix/format' (#573) --- colossalai/engine/ophooks/_memtracer_ophook.py | 4 ++-- colossalai/engine/schedule/_base_schedule.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/colossalai/engine/ophooks/_memtracer_ophook.py b/colossalai/engine/ophooks/_memtracer_ophook.py index b4de22f61..9535ded61 100644 --- a/colossalai/engine/ophooks/_memtracer_ophook.py +++ b/colossalai/engine/ophooks/_memtracer_ophook.py @@ -106,7 +106,7 @@ class MemTracerOpHook(BaseOpHook): # output file info self._logger.info(f"dump a memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl") home_dir = Path.home() - with open (home_dir.joinpath(f".cache/colossal/mem-{self._rank}.pkl"), "wb") as f: + with open(home_dir.joinpath(f".cache/colossal/mem-{self._rank}.pkl"), "wb") as f: pickle.dump(self.async_mem_monitor.state_dict, f) self._count += 1 self._logger.debug(f"data file has been refreshed {self._count} times") @@ -115,4 +115,4 @@ class MemTracerOpHook(BaseOpHook): def save_results(self, data_file: Union[str, Path]): with open(data_file, "w") as f: - f.write(json.dumps(self.async_mem_monitor.state_dict)) \ No newline at end of file + f.write(json.dumps(self.async_mem_monitor.state_dict)) diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/engine/schedule/_base_schedule.py index b23a4d0a6..0501d878b 100644 --- a/colossalai/engine/schedule/_base_schedule.py +++ b/colossalai/engine/schedule/_base_schedule.py @@ -85,8 +85,7 @@ class BaseSchedule(ABC): data_iter: Iterable, forward_only: bool, return_loss: bool = True, - return_output_label: bool = True - ): + return_output_label: bool = True): """The process function over a batch of dataset for training or evaluation. Args: @@ -107,8 +106,9 @@ class BaseSchedule(ABC): @staticmethod def _call_engine_criterion(engine, outputs, labels): - assert isinstance(outputs, (torch.Tensor, list, tuple) - ), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}' + assert isinstance( + outputs, + (torch.Tensor, list, tuple)), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}' if isinstance(outputs, torch.Tensor): outputs = (outputs,) if isinstance(labels, torch.Tensor):