# Clearing out my backlog of old code -- Happy New Year, everyone!
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
# CIFAR-10 split selected by train=False, downloaded into ./dataset/ when it
# is not already present; every sample is converted to a tensor.
data = torchvision.datasets.CIFAR10(
    root='./dataset/',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

# Serve the dataset in batches of 64 samples.
# NOTE: the misspelt name `data_lodaer` is kept because the logging loop
# further down refers to it.
data_lodaer = DataLoader(data, batch_size=64)
class Model(nn.Module):
    """Minimal network consisting of a single 2-D convolution.

    With the default arguments the layer maps 3-channel inputs to 6-channel
    outputs using a 3x3 kernel, stride 1 and padding 1, so the height and
    width of the input are preserved.

    Args:
        in_channels: Number of channels in the input images.
        out_channels: Number of feature maps produced by the convolution.
        kernel_size: Side length of the square convolution kernel.
        stride: Step of the convolution window.
        padding: Padding added to each side of the input.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 6,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
    ) -> None:
        super().__init__()
        # The attribute keeps its original name so that state_dict keys
        # ('conv2d.weight', 'conv2d.bias') do not change.
        self.conv2d = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x):
        """Run ``x`` through the convolution and return the result.

        Args:
            x: Input batch whose channel dimension matches ``in_channels``.

        Returns:
            The convolved batch, with ``out_channels`` channels.
        """
        return self.conv2d(x)
# Event files for TensorBoard are written under ./cifa10/.
writter = SummaryWriter('./cifa10/')
model = Model()

# Using the writer as a context manager closes it even when a batch raises.
# no_grad() skips building an autograd graph, which this logging-only loop
# never uses.
with writter, torch.no_grad():
    # Unpacking the batch directly avoids rebinding the module-level
    # dataset name `data`; the labels are not needed here.
    for step, (imgs, _) in enumerate(data_lodaer):
        out = model(imgs)
        writter.add_images('imgs', imgs, step)
        # The model emits 6 channels, which cannot be shown as an image
        # directly; fold them into the batch dimension as 3-channel images
        # (twice as many images as the input batch). Height and width are
        # taken from the output instead of being hard-coded.
        writter.add_images('out', out.reshape((-1, 3, *out.shape[-2:])), step)
|