I never really understood PyTorch's GPU memory mechanism. Today I read two articles and things finally started to make sense. The two articles are: https://zhuanlan.zhihu.com/p/424512257 and https://blog.csdn.net/qq_43827595/article/details/115722953
Based on the first (Zhihu) article and its comments, I wrote my own code to record how much GPU memory is occupied at each step. What follows is the Markdown content converted from my Jupyter notebook:
import torch
import torch.nn as nn

# A tiny two-layer model: a 1024x1024 and a 1024x1 weight matrix, no biases.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(1024, 1024, bias=False)
        self.linear2 = nn.Linear(1024, 1, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        out = sum(x)  # reduce to a single-element tensor so .backward() needs no argument
        return out
inputs = torch.tensor([[1.0]*1024]*1024).cuda()
print(torch.cuda.memory_allocated())
4194304
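This first number is easy to verify: the input is a 1024x1024 float32 tensor, so it takes 1024 * 1024 * 4 bytes. (The caching allocator rounds allocations up to 512-byte blocks, but this size is already a multiple of 512.)

# 1024 * 1024 elements * 4 bytes per float32 = 4,194,304 bytes
assert inputs.numel() * inputs.element_size() == 4194304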
net = Net().cuda()
print(torch.cuda.memory_allocated())
8392704
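Moving the model to the GPU adds the two weight matrices, and their float32 sizes account for the increase exactly:

# linear1.weight: 1024 * 1024 * 4 = 4,194,304 bytes
# linear2.weight: 1 * 1024 * 4    =     4,096 bytes
# 4,194,304 (inputs) + 4,194,304 + 4,096 = 8,392,704 bytes
print(sum(p.numel() * p.element_size() for p in net.parameters()))  # 4198400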
out = net(inputs)
print(torch.cuda.memory_allocated())
12587520
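The forward pass adds 12,587,520 - 8,392,704 = 4,194,816 bytes. The way I read it, almost all of that is the output of linear1, which autograd keeps alive because linear2's backward needs it; the last 512 bytes are the single-element `out` rounded up to the allocator's 512-byte block size.

# 4,194,304 bytes: output of linear1 (1024 * 1024 * 4), saved for linear2's backward
#       512 bytes: the single-element `out`, rounded up to one 512-byte block
# 8,392,704 + 4,194,304 + 512 = 12,587,520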
out.backward()
print(torch.cuda.memory_allocated())
12591616
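backward() only raises allocated memory by 4,096 bytes, because the saved activation is freed at the same time the gradient buffers are created. A rough balance (my interpretation of the numbers):

# freed:     4,194,304 bytes  (linear1's output, no longer needed after backward)
# allocated: 4,194,304 bytes  (linear1.weight.grad)
#          +     4,096 bytes  (linear2.weight.grad)
# net change:    4,096 bytes  -> 12,587,520 + 4,096 = 12,591,616
print(net.linear1.weight.grad.numel() * 4, net.linear2.weight.grad.numel() * 4)  # 4194304 4096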
from torch.optim import Adam
optimizer = Adam(net.parameters(), lr=1e-3)
print(torch.cuda.max_memory_allocated())
16790528
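Note that this cell prints max_memory_allocated, i.e. the peak so far, not the current allocation. Constructing the Adam optimizer itself should not allocate anything on the GPU (its state buffers are created lazily on the first step()), so the peak shown here was actually reached inside backward():

# currently allocated memory should be unchanged by the optimizer constructor:
print(torch.cuda.memory_allocated())      # still 12591616
print(torch.cuda.max_memory_allocated())  # 16790528, the peak hit during backward()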
optimizer.step()
print(torch.cuda.max_memory_allocated())
29377024
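optimizer.step() is where Adam's state is actually materialized: one exp_avg and one exp_avg_sq buffer per parameter, plus temporaries created while computing the update. A rough breakdown of the new peak (the exact temporaries depend on the PyTorch version):

# 29,377,024 - 16,790,528 = 12,586,496 extra bytes at the peak during step():
#   2 * (4,194,304 + 4,096) = 8,396,800 bytes of persistent Adam state (exp_avg, exp_avg_sq)
#   the remaining ~4.2 MB are temporaries used while computing the update
print(torch.cuda.memory_allocated())  # persistent memory now also holds the Adam state (~21 MB)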
16790528 - 12591616
4198912
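This last difference, peak memory minus memory still allocated after backward, is the transient memory that backward() needed only temporarily: the gradients with respect to the intermediate activations, which are freed as soon as the weight gradients that depend on them have been computed. One accounting that matches the number exactly (my reading, not verified tensor by tensor):

# grad w.r.t. linear1's output: 1024 * 1024 * 4 = 4,194,304 bytes
# grad w.r.t. linear2's output: 1024 * 1 * 4    =     4,096 bytes
# grad of `out` (one element, rounded up)       =       512 bytes
# transient peak during backward()              = 4,198,912 bytes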