mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] add deconstructor for stateful tensor (#848)
* add deconstructor for stateful tensor
* fix colo init context
parent 0f7ed8c192
commit 0dea140760
@@ -6,7 +6,7 @@ class GeminiMemoryManager(object):
     def __init__(self, states_cls: EnumMeta):
         super().__init__()
         self.states_cls = states_cls
-        self._cnter = 0    # the counter of instances
+        self._cnter = 0    # the counter of instances

         self.total_mem = dict()
         self.state_mem = dict()
@@ -20,10 +20,10 @@ class GeminiMemoryManager(object):
         return self._cnter

     def reset(self):
-        self._cnter = 0    # the counter of instances
+        self._cnter = 0    # the counter of instances

-        self.total_mem['cpu'] = 0    # memory occupation of instances in cpu
-        self.total_mem['cuda'] = 0    # memory of occupation of instances in cuda
+        self.total_mem['cpu'] = 0    # memory occupation of instances in cpu
+        self.total_mem['cuda'] = 0    # memory of occupation of instances in cuda

         # memory conditions for all states
         for state in self.states_cls:
@@ -33,13 +33,16 @@ class GeminiMemoryManager(object):
     def register_new_instance(self):
         self._cnter += 1

+    def delete_instance(self):
+        self._cnter -= 1
+
     def print_info(self):
-        print(
-            f"Total number: {self.total_number}",
-            f"Total CPU memory occupation: {self.total_mem['cpu']}",
-            f"Total CUDA memory occupation: {self.total_mem['cuda']}\n", sep='\n')
+        print(f"Total number: {self.total_number}",
+              f"Total CPU memory occupation: {self.total_mem['cpu']}",
+              f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
+              sep='\n')

         for state in self.states_cls:
-            print(
-                f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
-                f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n", sep='\n')
+            print(f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
+                  f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n",
+                  sep='\n')
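Aside from the style rewrapping of the print calls, the functional change in this hunk is the new delete_instance, the counterpart of register_new_instance; without it the instance counter can only grow. A tiny illustrative sketch of the balanced pair (the surrounding class is made up, only the method names mirror the diff):

```python
class CounterSketch:
    """Illustrative instance counter mirroring register_new_instance/delete_instance."""

    def __init__(self):
        self._cnter = 0

    @property
    def total_number(self):
        return self._cnter

    def register_new_instance(self):
        self._cnter += 1

    def delete_instance(self):
        self._cnter -= 1


mgr = CounterSketch()
mgr.register_new_instance()
mgr.register_new_instance()
mgr.delete_instance()          # without this counterpart the count never goes back down
print(mgr.total_number)        # 1
```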
@@ -202,3 +202,8 @@ class StatefulTensor(object):
         # update the information of each state
         manager.state_mem[from_type][state] -= size
         manager.state_mem[to_type][state] += size
+
+    def __del__(self):
+        self.set_null()
+        StatefulTensor.GST_MGR.delete_instance()
+        del self
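The new destructor releases the payload through set_null before decrementing the instance counter, so the manager's memory totals are corrected while the tensor's size is still known. A rough, self-contained sketch of that ordering; only set_null, GST_MGR and delete_instance come from the diff, the rest (MiniStatefulTensor, the byte accounting) is illustrative:

```python
class _Manager:
    """Stand-in for GeminiMemoryManager: counts instances and tracked bytes."""

    def __init__(self):
        self._cnter = 0
        self.total_mem = {'cpu': 0, 'cuda': 0}

    def register_new_instance(self):
        self._cnter += 1

    def delete_instance(self):
        self._cnter -= 1


class MiniStatefulTensor:
    GST_MGR = _Manager()    # shared, class-level manager, as in the diff

    def __init__(self, size_bytes: int, device: str = 'cpu'):
        self._size = size_bytes
        self._device = device
        MiniStatefulTensor.GST_MGR.register_new_instance()
        MiniStatefulTensor.GST_MGR.total_mem[device] += size_bytes

    def set_null(self):
        # Release the payload and give its bytes back to the manager's totals.
        if self._size:
            MiniStatefulTensor.GST_MGR.total_mem[self._device] -= self._size
            self._size = 0

    def __del__(self):
        self.set_null()                                  # release memory first
        MiniStatefulTensor.GST_MGR.delete_instance()     # then drop the live count


t = MiniStatefulTensor(1024, 'cuda')
del t
print(MiniStatefulTensor.GST_MGR._cnter, MiniStatefulTensor.GST_MGR.total_mem)
# 0 {'cpu': 0, 'cuda': 0}
```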
@@ -12,7 +12,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         super().__init__()
         self._lazy_memory_allocate = lazy_memory_allocate

-    def _post_init_method(self, module: torch.nn.Module):
+    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
         """
         The function to call at the end of the constructor of each module.
         FIXME(fjr) The module may be passed to this function multiple times?
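Widening _post_init_method to accept *args and **kwargs matters when the wrapper forwards each module's constructor arguments to the hook; with the narrow signature, any forwarded argument raises a TypeError. The following is not the actual InsertPostInitMethodToModuleSubClasses machinery, just a simplified stand-in showing why the hook needs the broader signature:

```python
import functools

import torch.nn as nn


def _post_init_method(module: nn.Module, *args, **kwargs):
    # Receives the freshly built module plus the arguments its constructor
    # was called with, which is why the widened signature is needed.
    print(f"post-init: {type(module).__name__}, args={args}, kwargs={kwargs}")


def install_post_init(cls):
    """Wrap cls.__init__ so the hook runs after construction (illustrative only)."""
    orig_init = cls.__init__

    @functools.wraps(orig_init)
    def wrapped(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        _post_init_method(self, *args, **kwargs)

    cls.__init__ = wrapped


install_post_init(nn.Linear)
layer = nn.Linear(4, 2, bias=False)    # hook prints ('Linear', (4, 2), {'bias': False})
```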