mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/utils/tensor_detector/tensor_detector.py code style (#1566)
parent
0c4c9aa6e0
commit
c7d4932956
|
@ -5,18 +5,17 @@ import torch.nn as nn
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
LINE_WIDTH = 108
|
LINE_WIDTH = 108
|
||||||
LINE = '-' * LINE_WIDTH + '\n'
|
LINE = '-' * LINE_WIDTH + '\n'
|
||||||
|
|
||||||
|
|
||||||
class TensorDetector():
|
class TensorDetector():
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
show_info: bool = True,
|
show_info: bool = True,
|
||||||
log: str = None,
|
log: str = None,
|
||||||
include_cpu: bool = False,
|
include_cpu: bool = False,
|
||||||
module: Optional[nn.Module] = None
|
module: Optional[nn.Module] = None):
|
||||||
):
|
|
||||||
"""This class is a detector to detect tensor on different devices.
|
"""This class is a detector to detect tensor on different devices.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -125,8 +124,7 @@ class TensorDetector():
|
||||||
minus = outdated + minus
|
minus = outdated + minus
|
||||||
if len(self.order) > 0:
|
if len(self.order) > 0:
|
||||||
for tensor_id in self.order:
|
for tensor_id in self.order:
|
||||||
self.info += template_format.format('+',
|
self.info += template_format.format('+', str(self.tensor_info[tensor_id][0]),
|
||||||
str(self.tensor_info[tensor_id][0]),
|
|
||||||
str(self.tensor_info[tensor_id][1]),
|
str(self.tensor_info[tensor_id][1]),
|
||||||
str(tuple(self.tensor_info[tensor_id][2])),
|
str(tuple(self.tensor_info[tensor_id][2])),
|
||||||
str(self.tensor_info[tensor_id][3]),
|
str(self.tensor_info[tensor_id][3]),
|
||||||
|
@ -137,8 +135,7 @@ class TensorDetector():
|
||||||
self.info += '\n'
|
self.info += '\n'
|
||||||
if len(minus) > 0:
|
if len(minus) > 0:
|
||||||
for tensor_id in minus:
|
for tensor_id in minus:
|
||||||
self.info += template_format.format('-',
|
self.info += template_format.format('-', str(self.saved_tensor_info[tensor_id][0]),
|
||||||
str(self.saved_tensor_info[tensor_id][0]),
|
|
||||||
str(self.saved_tensor_info[tensor_id][1]),
|
str(self.saved_tensor_info[tensor_id][1]),
|
||||||
str(tuple(self.saved_tensor_info[tensor_id][2])),
|
str(tuple(self.saved_tensor_info[tensor_id][2])),
|
||||||
str(self.saved_tensor_info[tensor_id][3]),
|
str(self.saved_tensor_info[tensor_id][3]),
|
||||||
|
@ -148,7 +145,6 @@ class TensorDetector():
|
||||||
# deleted the updated tensor
|
# deleted the updated tensor
|
||||||
self.saved_tensor_info.pop(tensor_id)
|
self.saved_tensor_info.pop(tensor_id)
|
||||||
|
|
||||||
|
|
||||||
# trace where is the detect()
|
# trace where is the detect()
|
||||||
locate_info = inspect.stack()[2]
|
locate_info = inspect.stack()[2]
|
||||||
locate_msg = '"' + locate_info.filename + '" line ' + str(locate_info.lineno)
|
locate_msg = '"' + locate_info.filename + '" line ' + str(locate_info.lineno)
|
||||||
|
|
Loading…
Reference in New Issue