Browse Source

[NFC] polish colossalai/utils/tensor_detector/tensor_detector.py code style (#1566)

pull/1550/head
LuGY 2 years ago committed by Frank Lee
parent
commit
c7d4932956
  1. 68
      colossalai/utils/tensor_detector/tensor_detector.py

68
colossalai/utils/tensor_detector/tensor_detector.py

@ -5,18 +5,17 @@ import torch.nn as nn
from typing import Optional from typing import Optional
from collections import defaultdict from collections import defaultdict
LINE_WIDTH = 108 LINE_WIDTH = 108
LINE = '-' * LINE_WIDTH + '\n' LINE = '-' * LINE_WIDTH + '\n'
class TensorDetector(): class TensorDetector():
def __init__(self, def __init__(self,
show_info: bool = True, show_info: bool = True,
log: str = None, log: str = None,
include_cpu: bool = False, include_cpu: bool = False,
module: Optional[nn.Module] = None module: Optional[nn.Module] = None):
):
"""This class is a detector to detect tensor on different devices. """This class is a detector to detect tensor on different devices.
Args: Args:
@ -28,7 +27,7 @@ class TensorDetector():
""" """
self.show_info = show_info self.show_info = show_info
self.log = log self.log = log
self.include_cpu = include_cpu self.include_cpu = include_cpu
self.tensor_info = defaultdict(list) self.tensor_info = defaultdict(list)
self.saved_tensor_info = defaultdict(list) self.saved_tensor_info = defaultdict(list)
self.order = [] self.order = []
@ -57,13 +56,13 @@ class TensorDetector():
def mem_format(self, real_memory_size): def mem_format(self, real_memory_size):
# format the tensor memory into a reasonable magnitude # format the tensor memory into a reasonable magnitude
if real_memory_size >= 2 ** 30: if real_memory_size >= 2**30:
return str(real_memory_size / (2 ** 30)) + ' GB' return str(real_memory_size / (2**30)) + ' GB'
if real_memory_size >= 2 ** 20: if real_memory_size >= 2**20:
return str(real_memory_size / (2 ** 20)) + ' MB' return str(real_memory_size / (2**20)) + ' MB'
if real_memory_size >= 2 ** 10: if real_memory_size >= 2**10:
return str(real_memory_size / (2 ** 10)) + ' KB' return str(real_memory_size / (2**10)) + ' KB'
return str(real_memory_size) + ' B' return str(real_memory_size) + ' B'
def collect_tensors_state(self): def collect_tensors_state(self):
for obj in gc.get_objects(): for obj in gc.get_objects():
@ -74,11 +73,11 @@ class TensorDetector():
self.detected.append(id(obj)) self.detected.append(id(obj))
# skip parameters we had added in __init__ when module is an instance of nn.Module for the first epoch # skip parameters we had added in __init__ when module is an instance of nn.Module for the first epoch
if id(obj) not in self.tensor_info: if id(obj) not in self.tensor_info:
name = type(obj).__name__ name = type(obj).__name__
# after backward, we want to update the records, to show you the change # after backward, we want to update the records, to show you the change
if isinstance(self.module, nn.Module) and name == 'Parameter': if isinstance(self.module, nn.Module) and name == 'Parameter':
if obj.grad is not None: if obj.grad is not None:
# with grad attached # with grad attached
for par_name, param in self.module.named_parameters(): for par_name, param in self.module.named_parameters():
if param.requires_grad and param.grad.equal(obj.grad): if param.requires_grad and param.grad.equal(obj.grad):
@ -88,7 +87,7 @@ class TensorDetector():
# there will be no new parameters created during running # there will be no new parameters created during running
# so it must be in saved_tensor_info # so it must be in saved_tensor_info
continue continue
# we can also mark common tensors as tensor(with grad) # we can also mark common tensors as tensor(with grad)
if name == 'Tensor' and (obj.is_leaf or obj.retains_grad): if name == 'Tensor' and (obj.is_leaf or obj.retains_grad):
if obj.grad is not None: if obj.grad is not None:
name = name + ' (with grad)' name = name + ' (with grad)'
@ -104,7 +103,7 @@ class TensorDetector():
self.tensor_info[id(obj)].append(obj.dtype) self.tensor_info[id(obj)].append(obj.dtype)
self.tensor_info[id(obj)].append(self.get_tensor_mem(obj)) self.tensor_info[id(obj)].append(self.get_tensor_mem(obj))
# recorded the order we got the tensor # recorded the order we got the tensor
# by this we can guess the tensor easily # by this we can guess the tensor easily
# it will record every tensor updated this turn # it will record every tensor updated this turn
self.order.append(id(obj)) self.order.append(id(obj))
# recorded all different devices # recorded all different devices
@ -114,7 +113,7 @@ class TensorDetector():
def print_tensors_state(self): def print_tensors_state(self):
template_format = '{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}' template_format = '{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}'
self.info += LINE self.info += LINE
self.info += template_format.format(' ', 'Tensor', 'device', 'shape', 'grad', 'dtype', 'Mem') self.info += template_format.format(' ', 'Tensor', 'device', 'shape', 'grad', 'dtype', 'Mem')
self.info += '\n' self.info += '\n'
self.info += LINE self.info += LINE
@ -122,36 +121,33 @@ class TensorDetector():
# it should be updated in the saved_tensor_info as well # it should be updated in the saved_tensor_info as well
outdated = [x for x in self.saved_tensor_info.keys() if x in self.order] outdated = [x for x in self.saved_tensor_info.keys() if x in self.order]
minus = [x for x in self.saved_tensor_info.keys() if x not in self.detected] minus = [x for x in self.saved_tensor_info.keys() if x not in self.detected]
minus = outdated + minus minus = outdated + minus
if len(self.order) > 0: if len(self.order) > 0:
for tensor_id in self.order: for tensor_id in self.order:
self.info += template_format.format('+', self.info += template_format.format('+', str(self.tensor_info[tensor_id][0]),
str(self.tensor_info[tensor_id][0]), str(self.tensor_info[tensor_id][1]),
str(self.tensor_info[tensor_id][1]), str(tuple(self.tensor_info[tensor_id][2])),
str(tuple(self.tensor_info[tensor_id][2])), str(self.tensor_info[tensor_id][3]),
str(self.tensor_info[tensor_id][3]), str(self.tensor_info[tensor_id][4]),
str(self.tensor_info[tensor_id][4]), str(self.tensor_info[tensor_id][5]))
str(self.tensor_info[tensor_id][5]))
self.info += '\n' self.info += '\n'
if len(self.order) > 0 and len(minus) > 0: if len(self.order) > 0 and len(minus) > 0:
self.info += '\n' self.info += '\n'
if len(minus) > 0: if len(minus) > 0:
for tensor_id in minus: for tensor_id in minus:
self.info += template_format.format('-', self.info += template_format.format('-', str(self.saved_tensor_info[tensor_id][0]),
str(self.saved_tensor_info[tensor_id][0]), str(self.saved_tensor_info[tensor_id][1]),
str(self.saved_tensor_info[tensor_id][1]), str(tuple(self.saved_tensor_info[tensor_id][2])),
str(tuple(self.saved_tensor_info[tensor_id][2])), str(self.saved_tensor_info[tensor_id][3]),
str(self.saved_tensor_info[tensor_id][3]), str(self.saved_tensor_info[tensor_id][4]),
str(self.saved_tensor_info[tensor_id][4]), str(self.saved_tensor_info[tensor_id][5]))
str(self.saved_tensor_info[tensor_id][5]))
self.info += '\n' self.info += '\n'
# deleted the updated tensor # deleted the updated tensor
self.saved_tensor_info.pop(tensor_id) self.saved_tensor_info.pop(tensor_id)
# trace where is the detect() # trace where is the detect()
locate_info = inspect.stack()[2] locate_info = inspect.stack()[2]
locate_msg = '"' + locate_info.filename + '" line ' + str(locate_info.lineno) locate_msg = '"' + locate_info.filename + '" line ' + str(locate_info.lineno)
self.info += LINE self.info += LINE
self.info += f"Detect Location: {locate_msg}\n" self.info += f"Detect Location: {locate_msg}\n"
@ -167,8 +163,8 @@ class TensorDetector():
if self.log is not None: if self.log is not None:
with open(self.log + '.log', 'a') as f: with open(self.log + '.log', 'a') as f:
f.write(self.info) f.write(self.info)
def detect(self, include_cpu = False): def detect(self, include_cpu=False):
self.include_cpu = include_cpu self.include_cpu = include_cpu
self.collect_tensors_state() self.collect_tensors_state()
self.print_tensors_state() self.print_tensors_state()
@ -180,4 +176,4 @@ class TensorDetector():
def close(self): def close(self):
self.saved_tensor_info.clear() self.saved_tensor_info.clear()
self.module = None self.module = None

Loading…
Cancel
Save