From 0dea1407604d4ac22bb7333ad6d213b430a99e81 Mon Sep 17 00:00:00 2001 From: ver217 Date: Sun, 24 Apr 2022 15:03:04 +0800 Subject: [PATCH] [hotfix] add deconstructor for stateful tensor (#848) * add deconstructor for stateful tensor * fix colo init context --- colossalai/gemini/gemini_context.py | 25 ++++++++++++--------- colossalai/gemini/stateful_tensor.py | 5 +++++ colossalai/utils/model/colo_init_context.py | 2 +- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/colossalai/gemini/gemini_context.py b/colossalai/gemini/gemini_context.py index aeade031f..98c8a914e 100644 --- a/colossalai/gemini/gemini_context.py +++ b/colossalai/gemini/gemini_context.py @@ -6,7 +6,7 @@ class GeminiMemoryManager(object): def __init__(self, states_cls: EnumMeta): super().__init__() self.states_cls = states_cls - self._cnter = 0 # the counter of instances + self._cnter = 0 # the counter of instances self.total_mem = dict() self.state_mem = dict() @@ -20,10 +20,10 @@ class GeminiMemoryManager(object): return self._cnter def reset(self): - self._cnter = 0 # the counter of instances + self._cnter = 0 # the counter of instances - self.total_mem['cpu'] = 0 # memory occupation of instances in cpu - self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda + self.total_mem['cpu'] = 0 # memory occupation of instances in cpu + self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda # memory conditions for all states for state in self.states_cls: @@ -33,13 +33,16 @@ class GeminiMemoryManager(object): def register_new_instance(self): self._cnter += 1 + def delete_instance(self): + self._cnter -= 1 + def print_info(self): - print( - f"Total number: {self.total_number}", - f"Total CPU memory occupation: {self.total_mem['cpu']}", - f"Total CUDA memory occupation: {self.total_mem['cuda']}\n", sep='\n') + print(f"Total number: {self.total_number}", + f"Total CPU memory occupation: {self.total_mem['cpu']}", + f"Total CUDA memory occupation: {self.total_mem['cuda']}\n", + sep='\n') for state in self.states_cls: - print( - f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}", - f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n", sep='\n') + print(f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}", + f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n", + sep='\n') diff --git a/colossalai/gemini/stateful_tensor.py b/colossalai/gemini/stateful_tensor.py index d6ab29cbe..18fc8fd14 100644 --- a/colossalai/gemini/stateful_tensor.py +++ b/colossalai/gemini/stateful_tensor.py @@ -202,3 +202,8 @@ class StatefulTensor(object): # update the information of each state manager.state_mem[from_type][state] -= size manager.state_mem[to_type][state] += size + + def __del__(self): + self.set_null() + StatefulTensor.GST_MGR.delete_instance() + del self diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py index 429bf2175..1e9efec0a 100644 --- a/colossalai/utils/model/colo_init_context.py +++ b/colossalai/utils/model/colo_init_context.py @@ -12,7 +12,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): super().__init__() self._lazy_memory_allocate = lazy_memory_allocate - def _post_init_method(self, module: torch.nn.Module): + def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): """ The function to call at the end of the constructor of each module. FIXME(fjr) The module may be passed to this function multiple times?