|
|
@ -5,18 +5,17 @@ import torch.nn as nn
|
|
|
|
from typing import Optional
|
|
|
|
from typing import Optional
|
|
|
|
from collections import defaultdict
|
|
|
|
from collections import defaultdict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LINE_WIDTH = 108
|
|
|
|
LINE_WIDTH = 108
|
|
|
|
LINE = '-' * LINE_WIDTH + '\n'
|
|
|
|
LINE = '-' * LINE_WIDTH + '\n'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TensorDetector():
|
|
|
|
class TensorDetector():
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
def __init__(self,
|
|
|
|
show_info: bool = True,
|
|
|
|
show_info: bool = True,
|
|
|
|
log: str = None,
|
|
|
|
log: str = None,
|
|
|
|
include_cpu: bool = False,
|
|
|
|
include_cpu: bool = False,
|
|
|
|
module: Optional[nn.Module] = None
|
|
|
|
module: Optional[nn.Module] = None):
|
|
|
|
):
|
|
|
|
|
|
|
|
"""This class is a detector to detect tensor on different devices.
|
|
|
|
"""This class is a detector to detect tensor on different devices.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
@ -57,12 +56,12 @@ class TensorDetector():
|
|
|
|
|
|
|
|
|
|
|
|
def mem_format(self, real_memory_size):
|
|
|
|
def mem_format(self, real_memory_size):
|
|
|
|
# format the tensor memory into a reasonal magnitude
|
|
|
|
# format the tensor memory into a reasonal magnitude
|
|
|
|
if real_memory_size >= 2 ** 30:
|
|
|
|
if real_memory_size >= 2**30:
|
|
|
|
return str(real_memory_size / (2 ** 30)) + ' GB'
|
|
|
|
return str(real_memory_size / (2**30)) + ' GB'
|
|
|
|
if real_memory_size >= 2 ** 20:
|
|
|
|
if real_memory_size >= 2**20:
|
|
|
|
return str(real_memory_size / (2 ** 20)) + ' MB'
|
|
|
|
return str(real_memory_size / (2**20)) + ' MB'
|
|
|
|
if real_memory_size >= 2 ** 10:
|
|
|
|
if real_memory_size >= 2**10:
|
|
|
|
return str(real_memory_size / (2 ** 10)) + ' KB'
|
|
|
|
return str(real_memory_size / (2**10)) + ' KB'
|
|
|
|
return str(real_memory_size) + ' B'
|
|
|
|
return str(real_memory_size) + ' B'
|
|
|
|
|
|
|
|
|
|
|
|
def collect_tensors_state(self):
|
|
|
|
def collect_tensors_state(self):
|
|
|
@ -125,33 +124,30 @@ class TensorDetector():
|
|
|
|
minus = outdated + minus
|
|
|
|
minus = outdated + minus
|
|
|
|
if len(self.order) > 0:
|
|
|
|
if len(self.order) > 0:
|
|
|
|
for tensor_id in self.order:
|
|
|
|
for tensor_id in self.order:
|
|
|
|
self.info += template_format.format('+',
|
|
|
|
self.info += template_format.format('+', str(self.tensor_info[tensor_id][0]),
|
|
|
|
str(self.tensor_info[tensor_id][0]),
|
|
|
|
str(self.tensor_info[tensor_id][1]),
|
|
|
|
str(self.tensor_info[tensor_id][1]),
|
|
|
|
str(tuple(self.tensor_info[tensor_id][2])),
|
|
|
|
str(tuple(self.tensor_info[tensor_id][2])),
|
|
|
|
str(self.tensor_info[tensor_id][3]),
|
|
|
|
str(self.tensor_info[tensor_id][3]),
|
|
|
|
str(self.tensor_info[tensor_id][4]),
|
|
|
|
str(self.tensor_info[tensor_id][4]),
|
|
|
|
str(self.tensor_info[tensor_id][5]))
|
|
|
|
str(self.tensor_info[tensor_id][5]))
|
|
|
|
|
|
|
|
self.info += '\n'
|
|
|
|
self.info += '\n'
|
|
|
|
if len(self.order) > 0 and len(minus) > 0:
|
|
|
|
if len(self.order) > 0 and len(minus) > 0:
|
|
|
|
self.info += '\n'
|
|
|
|
self.info += '\n'
|
|
|
|
if len(minus) > 0:
|
|
|
|
if len(minus) > 0:
|
|
|
|
for tensor_id in minus:
|
|
|
|
for tensor_id in minus:
|
|
|
|
self.info += template_format.format('-',
|
|
|
|
self.info += template_format.format('-', str(self.saved_tensor_info[tensor_id][0]),
|
|
|
|
str(self.saved_tensor_info[tensor_id][0]),
|
|
|
|
str(self.saved_tensor_info[tensor_id][1]),
|
|
|
|
str(self.saved_tensor_info[tensor_id][1]),
|
|
|
|
str(tuple(self.saved_tensor_info[tensor_id][2])),
|
|
|
|
str(tuple(self.saved_tensor_info[tensor_id][2])),
|
|
|
|
str(self.saved_tensor_info[tensor_id][3]),
|
|
|
|
str(self.saved_tensor_info[tensor_id][3]),
|
|
|
|
str(self.saved_tensor_info[tensor_id][4]),
|
|
|
|
str(self.saved_tensor_info[tensor_id][4]),
|
|
|
|
str(self.saved_tensor_info[tensor_id][5]))
|
|
|
|
str(self.saved_tensor_info[tensor_id][5]))
|
|
|
|
|
|
|
|
self.info += '\n'
|
|
|
|
self.info += '\n'
|
|
|
|
# deleted the updated tensor
|
|
|
|
# deleted the updated tensor
|
|
|
|
self.saved_tensor_info.pop(tensor_id)
|
|
|
|
self.saved_tensor_info.pop(tensor_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# trace where is the detect()
|
|
|
|
# trace where is the detect()
|
|
|
|
locate_info = inspect.stack()[2]
|
|
|
|
locate_info = inspect.stack()[2]
|
|
|
|
locate_msg = '"' + locate_info.filename + '" line ' + str(locate_info.lineno)
|
|
|
|
locate_msg = '"' + locate_info.filename + '" line ' + str(locate_info.lineno)
|
|
|
|
|
|
|
|
|
|
|
|
self.info += LINE
|
|
|
|
self.info += LINE
|
|
|
|
self.info += f"Detect Location: {locate_msg}\n"
|
|
|
|
self.info += f"Detect Location: {locate_msg}\n"
|
|
|
@ -168,7 +164,7 @@ class TensorDetector():
|
|
|
|
with open(self.log + '.log', 'a') as f:
|
|
|
|
with open(self.log + '.log', 'a') as f:
|
|
|
|
f.write(self.info)
|
|
|
|
f.write(self.info)
|
|
|
|
|
|
|
|
|
|
|
|
def detect(self, include_cpu = False):
|
|
|
|
def detect(self, include_cpu=False):
|
|
|
|
self.include_cpu = include_cpu
|
|
|
|
self.include_cpu = include_cpu
|
|
|
|
self.collect_tensors_state()
|
|
|
|
self.collect_tensors_state()
|
|
|
|
self.print_tensors_state()
|
|
|
|
self.print_tensors_state()
|
|
|
|