[hotfix] add destructor for stateful tensor (#848)

* add destructor (`__del__`) for stateful tensor

* fix colo init context
pull/850/head
ver217 2022-04-24 15:03:04 +08:00 committed by GitHub
parent 0f7ed8c192
commit 0dea140760
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 12 deletions

View File

@ -33,13 +33,16 @@ class GeminiMemoryManager(object):
def register_new_instance(self): def register_new_instance(self):
self._cnter += 1 self._cnter += 1
def delete_instance(self):
    """Decrement the instance counter by one.

    Counterpart of ``register_new_instance``, which increments the same
    ``_cnter`` attribute.
    """
    self._cnter = self._cnter - 1
def print_info(self): def print_info(self):
print( print(f"Total number: {self.total_number}",
f"Total number: {self.total_number}",
f"Total CPU memory occupation: {self.total_mem['cpu']}", f"Total CPU memory occupation: {self.total_mem['cpu']}",
f"Total CUDA memory occupation: {self.total_mem['cuda']}\n", sep='\n') f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
sep='\n')
for state in self.states_cls: for state in self.states_cls:
print( print(f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}", f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n",
f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n", sep='\n') sep='\n')

View File

@ -202,3 +202,8 @@ class StatefulTensor(object):
# update the information of each state # update the information of each state
manager.state_mem[from_type][state] -= size manager.state_mem[from_type][state] -= size
manager.state_mem[to_type][state] += size manager.state_mem[to_type][state] += size
def __del__(self):
    """Finalizer: release this tensor's state and unregister the instance.

    Calls ``set_null()`` on the tensor first, then tells the shared
    manager (``StatefulTensor.GST_MGR``) that one instance is gone via
    ``delete_instance()``.

    NOTE(review): presumably ``set_null()`` drops the payload and updates
    the manager's memory statistics -- confirm against its definition.
    """
    self.set_null()
    StatefulTensor.GST_MGR.delete_instance()
    # No ``del self`` here: it would only unbind the local name inside this
    # method. The object is already being finalized by the interpreter.

View File

@ -12,7 +12,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
super().__init__() super().__init__()
self._lazy_memory_allocate = lazy_memory_allocate self._lazy_memory_allocate = lazy_memory_allocate
def _post_init_method(self, module: torch.nn.Module): def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
""" """
The function to call at the end of the constructor of each module. The function to call at the end of the constructor of each module.
FIXME(fjr) The module may be passed to this function multiple times? FIXME(fjr) The module may be passed to this function multiple times?