mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
49 lines
1.4 KiB
49 lines
1.4 KiB
from enum import EnumMeta
|
|
|
|
|
|
class GeminiMemoryManager(object):
    """Book-keeping for a population of tracked instances and their memory use.

    Holds an instance counter plus memory tallies split by device
    (``'cpu'`` / ``'cuda'``) and, within each device, by every member of
    ``states_cls``. This class itself only ever sets the tallies to zero.
    NOTE(review): presumably the tracked objects update ``total_mem`` /
    ``state_mem`` directly from outside this class -- confirm against callers.

    Args:
        states_cls: An Enum class (annotated as ``EnumMeta``); iterating it
            yields the state members used as keys of ``state_mem[device]``.
    """

    def __init__(self, states_cls: EnumMeta):
        super().__init__()
        self.states_cls = states_cls
        self._cnter = 0  # number of registered instances

        # device name -> total memory occupation on that device
        self.total_mem = {}
        # device name -> {state member -> memory occupation in that state}
        self.state_mem = {'cpu': {}, 'cuda': {}}

        self.reset()

    @property
    def total_number(self):
        """Current number of registered instances."""
        return self._cnter

    def reset(self):
        """Zero the instance counter and every memory tally.

        The existing ``total_mem`` / ``state_mem`` dictionaries are updated in
        place rather than replaced, so outside references to them stay valid.
        """
        self._cnter = 0

        # overall occupation per device
        self.total_mem.update(cpu=0, cuda=0)

        # one zeroed entry per member of the state enum, for each device
        for device in ('cpu', 'cuda'):
            self.state_mem[device].update(dict.fromkeys(self.states_cls, 0))

    def register_new_instance(self):
        """Count one more tracked instance."""
        self._cnter = self._cnter + 1

    def delete_instance(self):
        """Count one fewer tracked instance.

        NOTE(review): there is no lower-bound check, so unbalanced calls can
        drive the counter negative.
        """
        self._cnter = self._cnter - 1

    def print_info(self):
        """Print the instance count, per-device totals and per-state tallies."""
        summary = (
            f"Total number: {self.total_number}",
            f"Total CPU memory occupation: {self.total_mem['cpu']}",
            f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
        )
        print('\n'.join(summary))

        cpu_by_state = self.state_mem['cpu']
        cuda_by_state = self.state_mem['cuda']
        for state in self.states_cls:
            lines = (
                f"{state}: CPU memory occupation: {cpu_by_state[state]}",
                f"{state}: CUDA memory occupation: {cuda_by_state[state]}\n",
            )
            print('\n'.join(lines))