[hotfix] rm test_tensor_detector.py (#413)

pull/395/head^2
Jiarui Fang 2022-03-14 21:39:48 +08:00 committed by GitHub
parent 370f567e7d
commit a37bf1bc42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 0 additions and 41 deletions

View File

@ -1,41 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
from colossalai.utils import TensorDetector
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.mlp = nn.Sequential(nn.Linear(64, 8),
nn.ReLU(),
nn.Linear(8, 32))
def forward(self, x):
return self.mlp(x)
def test_tensor_detect():
data = torch.rand(64, requires_grad=True).cuda()
data.retain_grad()
model = MLP().cuda()
detector = TensorDetector(log='test', include_cpu=False, module=model)
detector.detect()
out = model(data)
detector.detect()
loss = out.sum()
detector.detect()
loss.backward()
detector.detect()
model = MLP().cuda()
detector.detect()
detector.close()
torch.cuda.empty_cache()
if __name__ == '__main__':
test_tensor_detect()