diff --git a/colossalai/utils/tensor_detector/tensor_detector.py b/colossalai/utils/tensor_detector/tensor_detector.py index 616eb3b9e..a8186f768 100644 --- a/colossalai/utils/tensor_detector/tensor_detector.py +++ b/colossalai/utils/tensor_detector/tensor_detector.py @@ -5,18 +5,17 @@ import torch.nn as nn from typing import Optional from collections import defaultdict - LINE_WIDTH = 108 LINE = '-' * LINE_WIDTH + '\n' class TensorDetector(): + def __init__(self, show_info: bool = True, log: str = None, include_cpu: bool = False, - module: Optional[nn.Module] = None - ): + module: Optional[nn.Module] = None): """This class is a detector to detect tensor on different devices. Args: @@ -28,7 +27,7 @@ class TensorDetector(): """ self.show_info = show_info self.log = log - self.include_cpu = include_cpu + self.include_cpu = include_cpu self.tensor_info = defaultdict(list) self.saved_tensor_info = defaultdict(list) self.order = [] @@ -57,13 +56,13 @@ class TensorDetector(): def mem_format(self, real_memory_size): # format the tensor memory into a reasonal magnitude - if real_memory_size >= 2 ** 30: - return str(real_memory_size / (2 ** 30)) + ' GB' - if real_memory_size >= 2 ** 20: - return str(real_memory_size / (2 ** 20)) + ' MB' - if real_memory_size >= 2 ** 10: - return str(real_memory_size / (2 ** 10)) + ' KB' - return str(real_memory_size) + ' B' + if real_memory_size >= 2**30: + return str(real_memory_size / (2**30)) + ' GB' + if real_memory_size >= 2**20: + return str(real_memory_size / (2**20)) + ' MB' + if real_memory_size >= 2**10: + return str(real_memory_size / (2**10)) + ' KB' + return str(real_memory_size) + ' B' def collect_tensors_state(self): for obj in gc.get_objects(): @@ -74,11 +73,11 @@ class TensorDetector(): self.detected.append(id(obj)) # skip paramters we had added in __init__ when module is an instance of nn.Module for the first epoch if id(obj) not in self.tensor_info: - + name = type(obj).__name__ # after backward, we want to update the records, to show you the change if isinstance(self.module, nn.Module) and name == 'Parameter': - if obj.grad is not None: + if obj.grad is not None: # with grad attached for par_name, param in self.module.named_parameters(): if param.requires_grad and param.grad.equal(obj.grad): @@ -88,7 +87,7 @@ class TensorDetector(): # there will be no new paramters created during running # so it must be in saved_tensor_info continue - # we can also marked common tensors as tensor(with grad) + # we can also marked common tensors as tensor(with grad) if name == 'Tensor' and (obj.is_leaf or obj.retains_grad): if obj.grad is not None: name = name + ' (with grad)' @@ -104,7 +103,7 @@ class TensorDetector(): self.tensor_info[id(obj)].append(obj.dtype) self.tensor_info[id(obj)].append(self.get_tensor_mem(obj)) # recorded the order we got the tensor - # by this we can guess the tensor easily + # by this we can guess the tensor easily # it will record every tensor updated this turn self.order.append(id(obj)) # recorded all different devices @@ -114,7 +113,7 @@ class TensorDetector(): def print_tensors_state(self): template_format = '{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}' self.info += LINE - self.info += template_format.format(' ', 'Tensor', 'device', 'shape', 'grad', 'dtype', 'Mem') + self.info += template_format.format(' ', 'Tensor', 'device', 'shape', 'grad', 'dtype', 'Mem') self.info += '\n' self.info += LINE @@ -122,36 +121,33 @@ class TensorDetector(): # it should be updated in the saved_tensor_info as well outdated = [x for x in self.saved_tensor_info.keys() if x in self.order] minus = [x for x in self.saved_tensor_info.keys() if x not in self.detected] - minus = outdated + minus + minus = outdated + minus if len(self.order) > 0: for tensor_id in self.order: - self.info += template_format.format('+', - str(self.tensor_info[tensor_id][0]), - str(self.tensor_info[tensor_id][1]), - str(tuple(self.tensor_info[tensor_id][2])), - str(self.tensor_info[tensor_id][3]), - str(self.tensor_info[tensor_id][4]), - str(self.tensor_info[tensor_id][5])) + self.info += template_format.format('+', str(self.tensor_info[tensor_id][0]), + str(self.tensor_info[tensor_id][1]), + str(tuple(self.tensor_info[tensor_id][2])), + str(self.tensor_info[tensor_id][3]), + str(self.tensor_info[tensor_id][4]), + str(self.tensor_info[tensor_id][5])) self.info += '\n' if len(self.order) > 0 and len(minus) > 0: self.info += '\n' if len(minus) > 0: for tensor_id in minus: - self.info += template_format.format('-', - str(self.saved_tensor_info[tensor_id][0]), - str(self.saved_tensor_info[tensor_id][1]), - str(tuple(self.saved_tensor_info[tensor_id][2])), - str(self.saved_tensor_info[tensor_id][3]), - str(self.saved_tensor_info[tensor_id][4]), - str(self.saved_tensor_info[tensor_id][5])) + self.info += template_format.format('-', str(self.saved_tensor_info[tensor_id][0]), + str(self.saved_tensor_info[tensor_id][1]), + str(tuple(self.saved_tensor_info[tensor_id][2])), + str(self.saved_tensor_info[tensor_id][3]), + str(self.saved_tensor_info[tensor_id][4]), + str(self.saved_tensor_info[tensor_id][5])) self.info += '\n' # deleted the updated tensor self.saved_tensor_info.pop(tensor_id) - # trace where is the detect() locate_info = inspect.stack()[2] - locate_msg = '"' + locate_info.filename + '" line ' + str(locate_info.lineno) + locate_msg = '"' + locate_info.filename + '" line ' + str(locate_info.lineno) self.info += LINE self.info += f"Detect Location: {locate_msg}\n" @@ -167,8 +163,8 @@ class TensorDetector(): if self.log is not None: with open(self.log + '.log', 'a') as f: f.write(self.info) - - def detect(self, include_cpu = False): + + def detect(self, include_cpu=False): self.include_cpu = include_cpu self.collect_tensors_state() self.print_tensors_state() @@ -180,4 +176,4 @@ class TensorDetector(): def close(self): self.saved_tensor_info.clear() - self.module = None \ No newline at end of file + self.module = None