Added tensor detector (#393)

* Added tensor detector

* Added the `-` states for deleted tensors

* Allowed changing `include_cpu` when calling `detect()`
LuGY 2022-03-14 18:01:46 +08:00 committed by GitHub
parent 32296cf462
commit a9c27be42e
5 changed files with 361 additions and 1 deletion


@@ -10,7 +10,7 @@ from .data_sampler import DataParallelSampler, get_dataloader
from .gradient_accumulation import accumulate_gradient
from .memory import report_memory_usage
from .timer import MultiTimer, Timer
-#from .tensor_detector import TensorDetector
+from .tensor_detector import TensorDetector
__all__ = [
'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',


@@ -0,0 +1 @@
from .tensor_detector import TensorDetector


@@ -0,0 +1,128 @@
# Tensor Detector
This tool helps you detect tensors on both CPU and GPU. Be aware that some unfamiliar tensors will always show up on the CPU, such as the RNG states of PyTorch.
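CPU tensors are skipped by default and can be included per call, since `detect()` accepts `include_cpu` as an argument. A minimal sketch:

```python
from colossalai.utils import TensorDetector

detector = TensorDetector()
# CPU tensors (including PyTorch's RNG states) are skipped by default;
# pass include_cpu=True to list them as well for this call
detector.detect(include_cpu=True)
```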
## Example
An example is worth a thousand words.
The code below defines a simple MLP module, which we will use to show how the tool works.
```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(64, 8),
                                 nn.ReLU(),
                                 nn.Linear(8, 32))

    def forward(self, x):
        return self.mlp(x)
```
And here is how to use the tool.
```python
from colossalai.utils import TensorDetector
# create random data
data = torch.rand(64, requires_grad=True).cuda()
data.retain_grad()
# create the module
model = MLP().cuda()
# create the detector
# by passing the model to the detector, it can distinguish module parameters from common tensors
detector = TensorDetector(include_cpu=False, module=model)
detector.detect()
out = model(data)
detector.detect()
loss = out.sum()
loss.backward()
detector.detect()
```
I have added some comments on the right of the output to aid understanding.
Note that the total `Mem` of all the tensors and parameters does not equal `Total GPU Memory Allocated`. PyTorch's memory management is complicated, and for large-scale models it is impossible to account for every byte exactly.
**The order in which tensors are printed is not exactly the order in which they were created, but the two are very close.**
```bash
------------------------------------------------------------------------------------------------------------
   Tensor                            device               shape      grad               dtype            Mem
------------------------------------------------------------------------------------------------------------
+  Tensor                            cuda:0               (64,)      True       torch.float32          256 B   # data
+  mlp.0.weight                      cuda:0             (8, 64)      True       torch.float32         2.0 KB
+  mlp.0.bias                        cuda:0                (8,)      True       torch.float32           32 B
+  mlp.2.weight                      cuda:0             (32, 8)      True       torch.float32         1.0 KB
+  mlp.2.bias                        cuda:0               (32,)      True       torch.float32          128 B
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 27
Total GPU Memory Allocated on cuda:0 is 4.5 KB
------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------
   Tensor                            device               shape      grad               dtype            Mem
------------------------------------------------------------------------------------------------------------
+  Tensor                            cuda:0                (8,)      True       torch.float32           32 B   # activation
+  Tensor                            cuda:0               (32,)      True       torch.float32          128 B   # output
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 30
Total GPU Memory Allocated on cuda:0 is 5.5 KB
------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------
   Tensor                            device               shape      grad               dtype            Mem
------------------------------------------------------------------------------------------------------------
+  Tensor                            cuda:0                  ()      True       torch.float32            4 B   # loss
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 32
Total GPU Memory Allocated on cuda:0 is 6.0 KB
------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------
   Tensor                            device               shape      grad               dtype            Mem
------------------------------------------------------------------------------------------------------------
+  Tensor (with grad)                cuda:0               (64,)      True       torch.float32          512 B   # data with grad, kept by data.retain_grad()
+  mlp.0.weight (with grad)          cuda:0             (8, 64)      True       torch.float32         4.0 KB
+  mlp.0.bias (with grad)            cuda:0                (8,)      True       torch.float32           64 B
+  mlp.2.weight (with grad)          cuda:0             (32, 8)      True       torch.float32         2.0 KB
+  mlp.2.bias (with grad)            cuda:0               (32,)      True       torch.float32          256 B

-  mlp.0.weight                      cuda:0             (8, 64)      True       torch.float32         2.0 KB
-  mlp.0.bias                        cuda:0                (8,)      True       torch.float32           32 B
-  mlp.2.weight                      cuda:0             (32, 8)      True       torch.float32         1.0 KB
-  mlp.2.bias                        cuda:0               (32,)      True       torch.float32          128 B
-  Tensor                            cuda:0               (64,)      True       torch.float32          256 B
-  Tensor                            cuda:0                (8,)      True       torch.float32           32 B   # deleted activation
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 34
Total GPU Memory Allocated on cuda:0 is 10.0 KB
------------------------------------------------------------------------------------------------------------
------------------------------------------------------------------------------------------------------------
   Tensor                            device               shape      grad               dtype            Mem
------------------------------------------------------------------------------------------------------------
+  Tensor                            cuda:0               (64,)     False       torch.float32          256 B
+  Tensor                            cuda:0             (8, 64)     False       torch.float32         2.0 KB
+  Tensor                            cuda:0                (8,)     False       torch.float32           32 B
+  Tensor                            cuda:0             (32, 8)     False       torch.float32         1.0 KB
+  Tensor                            cuda:0               (32,)     False       torch.float32          128 B
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 36
Total GPU Memory Allocated on cuda:0 is 14.0 KB
------------------------------------------------------------------------------------------------------------
```
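The detector can also append every report to a log file, and should be closed once you are done with it. A minimal sketch, reusing the `model` from the example above (the log name `test` is illustrative):

```python
# passing log='test' makes each detect() call append its report to 'test.log'
detector = TensorDetector(log='test', include_cpu=False, module=model)
detector.detect()
# ... run the model as before ...
detector.close()    # clear the saved records and release the reference to the model
```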
## Reference
This tool was inspired by https://github.com/Stonesjtu/pytorch_memlab/blob/master/pytorch_memlab/mem_reporter.py
and https://github.com/Oldpan/Pytorch-Memory-Utils


@@ -0,0 +1,190 @@
import gc
import inspect
import torch
import torch.nn as nn
from typing import Optional
from collections import defaultdict

LINE_WIDTH = 108
LINE = '-' * LINE_WIDTH + '\n'


class TensorDetector():

    def __init__(self,
                 show_info: bool = True,
                 log: Optional[str] = None,
                 include_cpu: bool = False,
                 module: Optional[nn.Module] = None):
        """This class is a detector for tensors on different devices.

        :param show_info: whether to print the info on screen, default True
        :type show_info: bool
        :param log: the file name prefix to save the log (a '.log' suffix is appended)
        :type log: Optional[str]
        :param include_cpu: whether to detect tensors on the CPU, default False
        :type include_cpu: bool
        :param module: when an `nn.Module` is passed in, the detector can name the detected tensors better
        :type module: Optional[nn.Module]
        """
        self.show_info = show_info
        self.log = log
        self.include_cpu = include_cpu
        self.tensor_info = defaultdict(list)
        self.saved_tensor_info = defaultdict(list)
        self.order = []
        self.detected = []
        self.devices = []
        self.info = ""
        self.module = module
        if isinstance(module, nn.Module):
            # if module is an instance of nn.Module, we can name each parameter with its real name
            for name, param in module.named_parameters():
                self.tensor_info[id(param)].append(name)
                self.tensor_info[id(param)].append(param.device)
                self.tensor_info[id(param)].append(param.shape)
                self.tensor_info[id(param)].append(param.requires_grad)
                self.tensor_info[id(param)].append(param.dtype)
                self.tensor_info[id(param)].append(self.get_tensor_mem(param))

    def get_tensor_mem(self, tensor):
        # calculate the memory occupied by a tensor
        memory_size = tensor.element_size() * tensor.storage().size()
        # a leaf tensor (or one that retains grad) also owns the memory of its gradient
        if (tensor.is_leaf or tensor.retains_grad) and tensor.grad is not None:
            grad_memory_size = tensor.grad.element_size() * tensor.grad.storage().size()
            memory_size += grad_memory_size
        return self.mem_format(memory_size)

    def mem_format(self, real_memory_size):
        # format the tensor memory into a reasonable magnitude
        if real_memory_size >= 2 ** 30:
            return str(real_memory_size / (2 ** 30)) + ' GB'
        if real_memory_size >= 2 ** 20:
            return str(real_memory_size / (2 ** 20)) + ' MB'
        if real_memory_size >= 2 ** 10:
            return str(real_memory_size / (2 ** 10)) + ' KB'
        return str(real_memory_size) + ' B'

    def collect_tensors_state(self):
        for obj in gc.get_objects():
            if torch.is_tensor(obj):
                # skip CPU tensors when include_cpu is False
                if (not self.include_cpu) and obj.device == torch.device('cpu'):
                    continue
                self.detected.append(id(obj))
                # skip the parameters we already added in __init__ when module is an instance of nn.Module
                if id(obj) not in self.tensor_info:
                    name = type(obj).__name__
                    # after backward, we want to update the records to show the change
                    if isinstance(self.module, nn.Module) and name == 'Parameter':
                        if obj.grad is not None:
                            # with grad attached
                            for par_name, param in self.module.named_parameters():
                                if param.requires_grad and param.grad is not None and param.grad.equal(obj.grad):
                                    name = par_name + ' (with grad)'
                        else:
                            # with no grad attached
                            # no new parameters are created while running,
                            # so this parameter must already be in saved_tensor_info
                            continue
                    # we can also mark common tensors as 'Tensor (with grad)'
                    # in fact, a common tensor has no grad unless retain_grad() was called on it
                    if name == 'Tensor' and (obj.is_leaf or obj.retains_grad):
                        if obj.grad is not None:
                            name = name + ' (with grad)'
                    # skip tensors that were collected before and have not changed
                    if id(obj) in self.saved_tensor_info.keys() and name == self.saved_tensor_info[id(obj)][0]:
                        continue
                    self.tensor_info[id(obj)].append(name)
                    self.tensor_info[id(obj)].append(obj.device)
                    self.tensor_info[id(obj)].append(obj.shape)
                    self.tensor_info[id(obj)].append(obj.requires_grad)
                    self.tensor_info[id(obj)].append(obj.dtype)
                    self.tensor_info[id(obj)].append(self.get_tensor_mem(obj))
                    # record the order in which the tensors were found;
                    # this makes it easier to guess what each tensor is,
                    # and it covers every tensor updated in this round
                    self.order.append(id(obj))
                # record all distinct devices
                if obj.device not in self.devices:
                    self.devices.append(obj.device)

    def print_tensors_state(self):
        template_format = '{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}'
        self.info += LINE
        self.info += template_format.format(' ', 'Tensor', 'device', 'shape', 'grad', 'dtype', 'Mem')
        self.info += '\n'
        self.info += LINE

        # if a tensor was updated in this round and was recorded before,
        # it should be updated in saved_tensor_info as well
        outdated = [x for x in self.saved_tensor_info.keys() if x in self.order]
        minus = [x for x in self.saved_tensor_info.keys() if x not in self.detected]
        minus = outdated + minus
        if len(self.order) > 0:
            for tensor_id in self.order:
                self.info += template_format.format('+',
                                                    str(self.tensor_info[tensor_id][0]),
                                                    str(self.tensor_info[tensor_id][1]),
                                                    str(tuple(self.tensor_info[tensor_id][2])),
                                                    str(self.tensor_info[tensor_id][3]),
                                                    str(self.tensor_info[tensor_id][4]),
                                                    str(self.tensor_info[tensor_id][5]))
                self.info += '\n'
        if len(self.order) > 0 and len(minus) > 0:
            self.info += '\n'
        if len(minus) > 0:
            for tensor_id in minus:
                self.info += template_format.format('-',
                                                    str(self.saved_tensor_info[tensor_id][0]),
                                                    str(self.saved_tensor_info[tensor_id][1]),
                                                    str(tuple(self.saved_tensor_info[tensor_id][2])),
                                                    str(self.saved_tensor_info[tensor_id][3]),
                                                    str(self.saved_tensor_info[tensor_id][4]),
                                                    str(self.saved_tensor_info[tensor_id][5]))
                self.info += '\n'
                # delete the updated or vanished tensor from the saved records
                self.saved_tensor_info.pop(tensor_id)

        # trace where detect() was called
        locate_info = inspect.stack()[2]
        locate_msg = '"' + locate_info.filename + '" line ' + str(locate_info.lineno)
        self.info += LINE
        self.info += f"Detect Location: {locate_msg}\n"
        for device in self.devices:
            if device == torch.device('cpu'):
                continue
            gpu_mem_alloc = self.mem_format(torch.cuda.memory_allocated(device))
            self.info += f"Total GPU Memory Allocated on {device} is {gpu_mem_alloc}\n"
        self.info += LINE
        self.info += '\n\n'
        if self.show_info:
            print(self.info)
        if self.log is not None:
            with open(self.log + '.log', 'a') as f:
                f.write(self.info)

    def detect(self, include_cpu: bool = False):
        self.include_cpu = include_cpu
        self.collect_tensors_state()
        self.print_tensors_state()
        self.saved_tensor_info.update(self.tensor_info)
        self.tensor_info.clear()
        self.order = []
        self.detected = []
        self.info = ""

    def close(self):
        self.saved_tensor_info.clear()
        self.module = None


@@ -0,0 +1,41 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn

from colossalai.utils import TensorDetector


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(64, 8),
                                 nn.ReLU(),
                                 nn.Linear(8, 32))

    def forward(self, x):
        return self.mlp(x)


def test_tensor_detect():
    data = torch.rand(64, requires_grad=True).cuda()
    data.retain_grad()
    model = MLP().cuda()

    detector = TensorDetector(log='test', include_cpu=False, module=model)
    detector.detect()

    out = model(data)
    detector.detect()

    loss = out.sum()
    detector.detect()

    loss.backward()
    detector.detect()

    model = MLP().cuda()
    detector.detect()

    detector.close()
    torch.cuda.empty_cache()


if __name__ == '__main__':
    test_tensor_detect()