[hotfix] add deconstructor for stateful tensor (#848)

* add deconstructor for stateful tensor

* fix colo init context
pull/850/head
ver217 3 years ago committed by GitHub
parent 0f7ed8c192
commit 0dea140760
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -6,7 +6,7 @@ class GeminiMemoryManager(object):
def __init__(self, states_cls: EnumMeta):
super().__init__()
self.states_cls = states_cls
self._cnter = 0 # the counter of instances
self._cnter = 0 # the counter of instances
self.total_mem = dict()
self.state_mem = dict()
@ -20,10 +20,10 @@ class GeminiMemoryManager(object):
return self._cnter
def reset(self):
self._cnter = 0 # the counter of instances
self._cnter = 0 # the counter of instances
self.total_mem['cpu'] = 0 # memory occupation of instances in cpu
self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda
self.total_mem['cpu'] = 0 # memory occupation of instances in cpu
self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda
# memory conditions for all states
for state in self.states_cls:
@ -33,13 +33,16 @@ class GeminiMemoryManager(object):
def register_new_instance(self):
self._cnter += 1
def delete_instance(self):
self._cnter -= 1
def print_info(self):
print(
f"Total number: {self.total_number}",
f"Total CPU memory occupation: {self.total_mem['cpu']}",
f"Total CUDA memory occupation: {self.total_mem['cuda']}\n", sep='\n')
print(f"Total number: {self.total_number}",
f"Total CPU memory occupation: {self.total_mem['cpu']}",
f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
sep='\n')
for state in self.states_cls:
print(
f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n", sep='\n')
print(f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n",
sep='\n')

@ -202,3 +202,8 @@ class StatefulTensor(object):
# update the information of each state
manager.state_mem[from_type][state] -= size
manager.state_mem[to_type][state] += size
def __del__(self):
self.set_null()
StatefulTensor.GST_MGR.delete_instance()
del self

@ -12,7 +12,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
super().__init__()
self._lazy_memory_allocate = lazy_memory_allocate
def _post_init_method(self, module: torch.nn.Module):
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
"""
The function to call at the end of the constructor of each module.
FIXME(fjr) The module may be passed to this function multiple times?

Loading…
Cancel
Save