|
|
|
@ -5,18 +5,17 @@ import torch.nn as nn
|
|
|
|
|
from typing import Optional |
|
|
|
|
from collections import defaultdict |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LINE_WIDTH = 108 |
|
|
|
|
LINE = '-' * LINE_WIDTH + '\n' |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TensorDetector(): |
|
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
|
show_info: bool = True, |
|
|
|
|
log: str = None, |
|
|
|
|
include_cpu: bool = False, |
|
|
|
|
module: Optional[nn.Module] = None |
|
|
|
|
): |
|
|
|
|
module: Optional[nn.Module] = None): |
|
|
|
|
"""This class is a detector to detect tensor on different devices. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
@ -28,7 +27,7 @@ class TensorDetector():
|
|
|
|
|
""" |
|
|
|
|
self.show_info = show_info |
|
|
|
|
self.log = log |
|
|
|
|
self.include_cpu = include_cpu |
|
|
|
|
self.include_cpu = include_cpu |
|
|
|
|
self.tensor_info = defaultdict(list) |
|
|
|
|
self.saved_tensor_info = defaultdict(list) |
|
|
|
|
self.order = [] |
|
|
|
@ -57,13 +56,13 @@ class TensorDetector():
|
|
|
|
|
|
|
|
|
|
def mem_format(self, real_memory_size): |
|
|
|
|
# format the tensor memory into a reasonal magnitude |
|
|
|
|
if real_memory_size >= 2 ** 30: |
|
|
|
|
return str(real_memory_size / (2 ** 30)) + ' GB' |
|
|
|
|
if real_memory_size >= 2 ** 20: |
|
|
|
|
return str(real_memory_size / (2 ** 20)) + ' MB' |
|
|
|
|
if real_memory_size >= 2 ** 10: |
|
|
|
|
return str(real_memory_size / (2 ** 10)) + ' KB' |
|
|
|
|
return str(real_memory_size) + ' B' |
|
|
|
|
if real_memory_size >= 2**30: |
|
|
|
|
return str(real_memory_size / (2**30)) + ' GB' |
|
|
|
|
if real_memory_size >= 2**20: |
|
|
|
|
return str(real_memory_size / (2**20)) + ' MB' |
|
|
|
|
if real_memory_size >= 2**10: |
|
|
|
|
return str(real_memory_size / (2**10)) + ' KB' |
|
|
|
|
return str(real_memory_size) + ' B' |
|
|
|
|
|
|
|
|
|
def collect_tensors_state(self): |
|
|
|
|
for obj in gc.get_objects(): |
|
|
|
@ -74,11 +73,11 @@ class TensorDetector():
|
|
|
|
|
self.detected.append(id(obj)) |
|
|
|
|
# skip paramters we had added in __init__ when module is an instance of nn.Module for the first epoch |
|
|
|
|
if id(obj) not in self.tensor_info: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
name = type(obj).__name__ |
|
|
|
|
# after backward, we want to update the records, to show you the change |
|
|
|
|
if isinstance(self.module, nn.Module) and name == 'Parameter': |
|
|
|
|
if obj.grad is not None: |
|
|
|
|
if obj.grad is not None: |
|
|
|
|
# with grad attached |
|
|
|
|
for par_name, param in self.module.named_parameters(): |
|
|
|
|
if param.requires_grad and param.grad.equal(obj.grad): |
|
|
|
@ -88,7 +87,7 @@ class TensorDetector():
|
|
|
|
|
# there will be no new paramters created during running |
|
|
|
|
# so it must be in saved_tensor_info |
|
|
|
|
continue |
|
|
|
|
# we can also marked common tensors as tensor(with grad) |
|
|
|
|
# we can also marked common tensors as tensor(with grad) |
|
|
|
|
if name == 'Tensor' and (obj.is_leaf or obj.retains_grad): |
|
|
|
|
if obj.grad is not None: |
|
|
|
|
name = name + ' (with grad)' |
|
|
|
@ -104,7 +103,7 @@ class TensorDetector():
|
|
|
|
|
self.tensor_info[id(obj)].append(obj.dtype) |
|
|
|
|
self.tensor_info[id(obj)].append(self.get_tensor_mem(obj)) |
|
|
|
|
# recorded the order we got the tensor |
|
|
|
|
# by this we can guess the tensor easily |
|
|
|
|
# by this we can guess the tensor easily |
|
|
|
|
# it will record every tensor updated this turn |
|
|
|
|
self.order.append(id(obj)) |
|
|
|
|
# recorded all different devices |
|
|
|
@ -114,7 +113,7 @@ class TensorDetector():
|
|
|
|
|
def print_tensors_state(self): |
|
|
|
|
template_format = '{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}' |
|
|
|
|
self.info += LINE |
|
|
|
|
self.info += template_format.format(' ', 'Tensor', 'device', 'shape', 'grad', 'dtype', 'Mem') |
|
|
|
|
self.info += template_format.format(' ', 'Tensor', 'device', 'shape', 'grad', 'dtype', 'Mem') |
|
|
|
|
self.info += '\n' |
|
|
|
|
self.info += LINE |
|
|
|
|
|
|
|
|
@ -122,36 +121,33 @@ class TensorDetector():
|
|
|
|
|
# it should be updated in the saved_tensor_info as well |
|
|
|
|
outdated = [x for x in self.saved_tensor_info.keys() if x in self.order] |
|
|
|
|
minus = [x for x in self.saved_tensor_info.keys() if x not in self.detected] |
|
|
|
|
minus = outdated + minus |
|
|
|
|
minus = outdated + minus |
|
|
|
|
if len(self.order) > 0: |
|
|
|
|
for tensor_id in self.order: |
|
|
|
|
self.info += template_format.format('+', |
|
|
|
|
str(self.tensor_info[tensor_id][0]), |
|
|
|
|
str(self.tensor_info[tensor_id][1]), |
|
|
|
|
str(tuple(self.tensor_info[tensor_id][2])), |
|
|
|
|
str(self.tensor_info[tensor_id][3]), |
|
|
|
|
str(self.tensor_info[tensor_id][4]), |
|
|
|
|
str(self.tensor_info[tensor_id][5])) |
|
|
|
|
self.info += template_format.format('+', str(self.tensor_info[tensor_id][0]), |
|
|
|
|
str(self.tensor_info[tensor_id][1]), |
|
|
|
|
str(tuple(self.tensor_info[tensor_id][2])), |
|
|
|
|
str(self.tensor_info[tensor_id][3]), |
|
|
|
|
str(self.tensor_info[tensor_id][4]), |
|
|
|
|
str(self.tensor_info[tensor_id][5])) |
|
|
|
|
self.info += '\n' |
|
|
|
|
if len(self.order) > 0 and len(minus) > 0: |
|
|
|
|
self.info += '\n' |
|
|
|
|
if len(minus) > 0: |
|
|
|
|
for tensor_id in minus: |
|
|
|
|
self.info += template_format.format('-', |
|
|
|
|
str(self.saved_tensor_info[tensor_id][0]), |
|
|
|
|
str(self.saved_tensor_info[tensor_id][1]), |
|
|
|
|
str(tuple(self.saved_tensor_info[tensor_id][2])), |
|
|
|
|
str(self.saved_tensor_info[tensor_id][3]), |
|
|
|
|
str(self.saved_tensor_info[tensor_id][4]), |
|
|
|
|
str(self.saved_tensor_info[tensor_id][5])) |
|
|
|
|
self.info += template_format.format('-', str(self.saved_tensor_info[tensor_id][0]), |
|
|
|
|
str(self.saved_tensor_info[tensor_id][1]), |
|
|
|
|
str(tuple(self.saved_tensor_info[tensor_id][2])), |
|
|
|
|
str(self.saved_tensor_info[tensor_id][3]), |
|
|
|
|
str(self.saved_tensor_info[tensor_id][4]), |
|
|
|
|
str(self.saved_tensor_info[tensor_id][5])) |
|
|
|
|
self.info += '\n' |
|
|
|
|
# deleted the updated tensor |
|
|
|
|
self.saved_tensor_info.pop(tensor_id) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# trace where is the detect() |
|
|
|
|
locate_info = inspect.stack()[2] |
|
|
|
|
locate_msg = '"' + locate_info.filename + '" line ' + str(locate_info.lineno) |
|
|
|
|
locate_msg = '"' + locate_info.filename + '" line ' + str(locate_info.lineno) |
|
|
|
|
|
|
|
|
|
self.info += LINE |
|
|
|
|
self.info += f"Detect Location: {locate_msg}\n" |
|
|
|
@ -167,8 +163,8 @@ class TensorDetector():
|
|
|
|
|
if self.log is not None: |
|
|
|
|
with open(self.log + '.log', 'a') as f: |
|
|
|
|
f.write(self.info) |
|
|
|
|
|
|
|
|
|
def detect(self, include_cpu = False): |
|
|
|
|
|
|
|
|
|
def detect(self, include_cpu=False): |
|
|
|
|
self.include_cpu = include_cpu |
|
|
|
|
self.collect_tensors_state() |
|
|
|
|
self.print_tensors_state() |
|
|
|
@ -180,4 +176,4 @@ class TensorDetector():
|
|
|
|
|
|
|
|
|
|
def close(self): |
|
|
|
|
self.saved_tensor_info.clear() |
|
|
|
|
self.module = None |
|
|
|
|
self.module = None |
|
|
|
|