mirror of https://github.com/hpcaitech/ColossalAI
Added tensor detector (#393)
* Added tensor detector
* Added the `-` states
* Allowed changing include_cpu when calling detect()

pull/413/head
parent 32296cf462
commit a9c27be42e

@@ -10,7 +10,7 @@ from .data_sampler import DataParallelSampler, get_dataloader
 from .gradient_accumulation import accumulate_gradient
 from .memory import report_memory_usage
 from .timer import MultiTimer, Timer
-#from .tensor_detector import TensorDetector
+from .tensor_detector import TensorDetector

 __all__ = [
     'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',

@@ -0,0 +1 @@
from .tensor_detector import TensorDetector

@@ -0,0 +1,128 @@
# Tensor Detector

This tool helps you detect tensors on both CPU and GPU. Note that there will always be some unexpected tensors on the CPU, including PyTorch's rng state.
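
As a rough sketch of the mechanism (confirmed by the implementation below), the detector walks the objects tracked by Python's garbage collector and keeps the live tensors; running such a scan yourself also surfaces the CPU-side tensors, such as the rng state mentioned above:

```python
# Minimal sketch of the scan TensorDetector builds on: enumerate the objects
# tracked by the garbage collector and keep the live tensors.
import gc

import torch

for obj in gc.get_objects():
    if torch.is_tensor(obj):
        print(type(obj).__name__, obj.device, tuple(obj.shape))
```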

## Example

An example is worth a thousand words.

The code below defines a simple MLP module, with which we will show you how to use the tool.

```python
import torch
import torch.nn as nn


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(64, 8),
                                 nn.ReLU(),
                                 nn.Linear(8, 32))

    def forward(self, x):
        return self.mlp(x)
```

And here is how to use the tool.

```python
from colossalai.utils import TensorDetector

# create random data
data = torch.rand(64, requires_grad=True).cuda()
data.retain_grad()
# create the module
model = MLP().cuda()
# create the detector
# by passing the model to the detector, it can distinguish module parameters from common tensors
detector = TensorDetector(include_cpu=False, module=model)
detector.detect()

out = model(data)

detector.detect()

loss = out.sum()
loss.backward()

detector.detect()
```
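
`detect()` also accepts an `include_cpu` flag so the CPU scan can be toggled per call, `close()` clears the detector's saved records, and the constructor's `log` argument appends each report to `<log>.log`. A short sketch, reusing `detector` and `model` from above (`'tensor_report'` is just an illustrative name):

```python
# include_cpu can also be chosen per call; CPU scans are noisy because PyTorch
# keeps internal tensors (e.g. the rng state) on the CPU
detector.detect(include_cpu=True)

# write reports to 'tensor_report.log' instead of printing
# ('tensor_report' is a hypothetical file name prefix)
quiet_detector = TensorDetector(show_info=False, log='tensor_report', module=model)
quiet_detector.detect()

detector.close()   # clear saved records and drop the module reference
```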

I have added some comments on the right of the output for your understanding.

Note that the total `Mem` of all the tensors and parameters does not equal `Total GPU Memory Allocated`. PyTorch's memory management is complicated, and for large-scale models it is impossible to account for every allocated byte precisely.
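
One concrete contributor, shown in a minimal sketch below (not part of the tool): the CUDA caching allocator rounds each allocation up to its block size, so even a tiny tensor reserves more memory than its raw byte count.

```python
import torch

x = torch.rand(8, device='cuda')                   # 8 * 4 B = 32 B of raw data
raw_bytes = x.element_size() * x.numel()           # 32
allocated = torch.cuda.memory_allocated('cuda:0')  # typically 512 on an otherwise empty device
print(raw_bytes, allocated)
```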

**The printing order does not exactly match the order in which the tensors were created (they are discovered via `gc.get_objects()`), but it is close.**

```bash
------------------------------------------------------------------------------------------------------------
   Tensor                            device               shape      grad               dtype            Mem
------------------------------------------------------------------------------------------------------------
+  Tensor                            cuda:0               (64,)      True       torch.float32          256 B # data
+  mlp.0.weight                      cuda:0             (8, 64)      True       torch.float32         2.0 KB
+  mlp.0.bias                        cuda:0                (8,)      True       torch.float32           32 B
+  mlp.2.weight                      cuda:0             (32, 8)      True       torch.float32         1.0 KB
+  mlp.2.bias                        cuda:0               (32,)      True       torch.float32          128 B
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 27
Total GPU Memory Allocated on cuda:0 is 4.5 KB
------------------------------------------------------------------------------------------------------------


------------------------------------------------------------------------------------------------------------
   Tensor                            device               shape      grad               dtype            Mem
------------------------------------------------------------------------------------------------------------
+  Tensor                            cuda:0                (8,)      True       torch.float32           32 B # activation
+  Tensor                            cuda:0               (32,)      True       torch.float32          128 B # output
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 30
Total GPU Memory Allocated on cuda:0 is 5.5 KB
------------------------------------------------------------------------------------------------------------


------------------------------------------------------------------------------------------------------------
   Tensor                            device               shape      grad               dtype            Mem
------------------------------------------------------------------------------------------------------------
+  Tensor                            cuda:0                  ()      True       torch.float32            4 B # loss
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 32
Total GPU Memory Allocated on cuda:0 is 6.0 KB
------------------------------------------------------------------------------------------------------------


------------------------------------------------------------------------------------------------------------
   Tensor                            device               shape      grad               dtype            Mem
------------------------------------------------------------------------------------------------------------
+  Tensor (with grad)                cuda:0               (64,)      True       torch.float32          512 B # data, grad kept by retain_grad()
+  mlp.0.weight (with grad)          cuda:0             (8, 64)      True       torch.float32         4.0 KB
+  mlp.0.bias (with grad)            cuda:0                (8,)      True       torch.float32           64 B
+  mlp.2.weight (with grad)          cuda:0             (32, 8)      True       torch.float32         2.0 KB
+  mlp.2.bias (with grad)            cuda:0               (32,)      True       torch.float32          256 B

-  mlp.0.weight                      cuda:0             (8, 64)      True       torch.float32         2.0 KB
-  mlp.0.bias                        cuda:0                (8,)      True       torch.float32           32 B
-  mlp.2.weight                      cuda:0             (32, 8)      True       torch.float32         1.0 KB
-  mlp.2.bias                        cuda:0               (32,)      True       torch.float32          128 B
-  Tensor                            cuda:0               (64,)      True       torch.float32          256 B
-  Tensor                            cuda:0                (8,)      True       torch.float32           32 B # deleted activation
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 34
Total GPU Memory Allocated on cuda:0 is 10.0 KB
------------------------------------------------------------------------------------------------------------


------------------------------------------------------------------------------------------------------------
   Tensor                            device               shape      grad               dtype            Mem
------------------------------------------------------------------------------------------------------------
+  Tensor                            cuda:0               (64,)     False       torch.float32          256 B
+  Tensor                            cuda:0             (8, 64)     False       torch.float32         2.0 KB
+  Tensor                            cuda:0                (8,)     False       torch.float32           32 B
+  Tensor                            cuda:0             (32, 8)     False       torch.float32         1.0 KB
+  Tensor                            cuda:0               (32,)     False       torch.float32          128 B
------------------------------------------------------------------------------------------------------------
Detect Location: "test_tensor_detector.py" line 36
Total GPU Memory Allocated on cuda:0 is 14.0 KB
------------------------------------------------------------------------------------------------------------
```

## Reference

This tool was inspired by https://github.com/Stonesjtu/pytorch_memlab/blob/master/pytorch_memlab/mem_reporter.py
and https://github.com/Oldpan/Pytorch-Memory-Utils

@@ -0,0 +1,190 @@
import gc
import inspect
import torch
import torch.nn as nn
from typing import Optional
from collections import defaultdict


LINE_WIDTH = 108
LINE = '-' * LINE_WIDTH + '\n'


class TensorDetector():

    def __init__(self,
                 show_info: bool = True,
                 log: Optional[str] = None,
                 include_cpu: bool = False,
                 module: Optional[nn.Module] = None
                 ):
        """This class is a detector for tensors on different devices.

        :param show_info: whether to print the info on screen, default True
        :type show_info: bool
        :param log: the file name prefix used to save the log; a `.log` suffix is appended
        :type log: Optional[str]
        :param include_cpu: whether to detect tensors on CPU, default False
        :type include_cpu: bool
        :param module: when an `nn.Module` is passed in, the detector can name the detected parameters with their real names
        :type module: Optional[nn.Module]

        """
        self.show_info = show_info
        self.log = log
        self.include_cpu = include_cpu
        self.tensor_info = defaultdict(list)
        self.saved_tensor_info = defaultdict(list)
        self.order = []
        self.detected = []
        self.devices = []
        self.info = ""

        self.module = module
        if isinstance(module, nn.Module):
            # if module is an instance of nn.Module, we can name each parameter with its real name
            for name, param in module.named_parameters():
                self.tensor_info[id(param)].append(name)
                self.tensor_info[id(param)].append(param.device)
                self.tensor_info[id(param)].append(param.shape)
                self.tensor_info[id(param)].append(param.requires_grad)
                self.tensor_info[id(param)].append(param.dtype)
                self.tensor_info[id(param)].append(self.get_tensor_mem(param))

    def get_tensor_mem(self, tensor):
        # calculate the memory occupied by a tensor, including its gradient if one is attached
        memory_size = tensor.element_size() * tensor.storage().size()
        if (tensor.is_leaf or tensor.retains_grad) and tensor.grad is not None:
            grad_memory_size = tensor.grad.element_size() * tensor.grad.storage().size()
            memory_size += grad_memory_size
        return self.mem_format(memory_size)

    def mem_format(self, real_memory_size):
        # format the tensor memory into a reasonable magnitude
        if real_memory_size >= 2 ** 30:
            return str(real_memory_size / (2 ** 30)) + ' GB'
        if real_memory_size >= 2 ** 20:
            return str(real_memory_size / (2 ** 20)) + ' MB'
        if real_memory_size >= 2 ** 10:
            return str(real_memory_size / (2 ** 10)) + ' KB'
        return str(real_memory_size) + ' B'

    def collect_tensors_state(self):
        for obj in gc.get_objects():
            if torch.is_tensor(obj):
                # skip CPU tensors when include_cpu is False
                if (not self.include_cpu) and obj.device == torch.device('cpu'):
                    continue
                self.detected.append(id(obj))
                # skip parameters already recorded in __init__ when module is an instance of nn.Module
                if id(obj) not in self.tensor_info:

                    name = type(obj).__name__
                    # after backward, we want to update the records to show the change
                    if isinstance(self.module, nn.Module) and name == 'Parameter':
                        if obj.grad is not None:
                            # with grad attached: recover the parameter name by matching gradients
                            for par_name, param in self.module.named_parameters():
                                if param.requires_grad and param.grad is not None and param.grad.equal(obj.grad):
                                    name = par_name + ' (with grad)'
                        else:
                            # with no grad attached
                            # no new parameters are created while running,
                            # so it must already be in saved_tensor_info
                            continue
                    # common tensors can also be marked as (with grad);
                    # in fact, a common tensor has no grad unless retain_grad() was called
                    if name == 'Tensor' and (obj.is_leaf or obj.retains_grad):
                        if obj.grad is not None:
                            name = name + ' (with grad)'
                    if id(obj) in self.saved_tensor_info.keys() and name == self.saved_tensor_info[id(obj)][0]:
                        continue

                    self.tensor_info[id(obj)].append(name)
                    self.tensor_info[id(obj)].append(obj.device)
                    self.tensor_info[id(obj)].append(obj.shape)
                    self.tensor_info[id(obj)].append(obj.requires_grad)
                    self.tensor_info[id(obj)].append(obj.dtype)
                    self.tensor_info[id(obj)].append(self.get_tensor_mem(obj))
                    # record the order in which we found the tensors;
                    # this makes it easier to guess what each tensor is.
                    # every tensor updated this turn is recorded
                    self.order.append(id(obj))
                # record all distinct devices
                if obj.device not in self.devices:
                    self.devices.append(obj.device)

    def print_tensors_state(self):
        template_format = '{:3s}{:<30s}{:>10s}{:>20s}{:>10s}{:>20s}{:>15s}'
        self.info += LINE
        self.info += template_format.format(' ', 'Tensor', 'device', 'shape', 'grad', 'dtype', 'Mem')
        self.info += '\n'
        self.info += LINE

        # if a tensor was updated this turn and had been recorded before,
        # its entry in saved_tensor_info is outdated and is listed as removed
        outdated = [x for x in self.saved_tensor_info.keys() if x in self.order]
        minus = [x for x in self.saved_tensor_info.keys() if x not in self.detected]
        minus = outdated + minus
        if len(self.order) > 0:
            for tensor_id in self.order:
                self.info += template_format.format('+',
                                                    str(self.tensor_info[tensor_id][0]),
                                                    str(self.tensor_info[tensor_id][1]),
                                                    str(tuple(self.tensor_info[tensor_id][2])),
                                                    str(self.tensor_info[tensor_id][3]),
                                                    str(self.tensor_info[tensor_id][4]),
                                                    str(self.tensor_info[tensor_id][5]))
                self.info += '\n'
        if len(self.order) > 0 and len(minus) > 0:
            self.info += '\n'
        if len(minus) > 0:
            for tensor_id in minus:
                self.info += template_format.format('-',
                                                    str(self.saved_tensor_info[tensor_id][0]),
                                                    str(self.saved_tensor_info[tensor_id][1]),
                                                    str(tuple(self.saved_tensor_info[tensor_id][2])),
                                                    str(self.saved_tensor_info[tensor_id][3]),
                                                    str(self.saved_tensor_info[tensor_id][4]),
                                                    str(self.saved_tensor_info[tensor_id][5]))
                self.info += '\n'
                # drop the stale entry; updated entries are re-saved in detect()
                self.saved_tensor_info.pop(tensor_id)

        # trace where detect() was called
        locate_info = inspect.stack()[2]
        locate_msg = '"' + locate_info.filename + '" line ' + str(locate_info.lineno)

        self.info += LINE
        self.info += f"Detect Location: {locate_msg}\n"
        for device in self.devices:
            if device == torch.device('cpu'):
                continue
            gpu_mem_alloc = self.mem_format(torch.cuda.memory_allocated(device))
            self.info += f"Total GPU Memory Allocated on {device} is {gpu_mem_alloc}\n"
        self.info += LINE
        self.info += '\n\n'
        if self.show_info:
            print(self.info)
        if self.log is not None:
            with open(self.log + '.log', 'a') as f:
                f.write(self.info)

    def detect(self, include_cpu: bool = False):
        self.include_cpu = include_cpu
        self.collect_tensors_state()
        self.print_tensors_state()
        self.saved_tensor_info.update(self.tensor_info)
        self.tensor_info.clear()
        self.order = []
        self.detected = []
        self.info = ""

    def close(self):
        self.saved_tensor_info.clear()
        self.module = None

@@ -0,0 +1,41 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
import torch.nn as nn

from colossalai.utils import TensorDetector


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(64, 8),
                                 nn.ReLU(),
                                 nn.Linear(8, 32))

    def forward(self, x):
        return self.mlp(x)


def test_tensor_detect():
    data = torch.rand(64, requires_grad=True).cuda()
    data.retain_grad()
    model = MLP().cuda()

    detector = TensorDetector(log='test', include_cpu=False, module=model)

    detector.detect()
    out = model(data)

    detector.detect()
    loss = out.sum()
    detector.detect()
    loss.backward()
    detector.detect()
    model = MLP().cuda()
    detector.detect()
    detector.close()
    torch.cuda.empty_cache()


if __name__ == '__main__':
    test_tensor_detect()