从零开始实现:https://blog.csdn.net/tongjingqi_/article/details/122766296
import torch
import matplotlib.pyplot as plt
from torch import nn
from d2l import torch as d2l
d2l.use_svg_display()
batch_size=256
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)
net=nn.Sequential(nn.Flatten(),nn.Linear(784,10))
def init_weights(m):
if type(m)==nn.Linear:
nn.init.normal_(m.weight,std=0.01)
net.apply(init_weights)
loss=nn.CrossEntropyLoss(reduction='none')
trainer=torch.optim.SGD(net.parameters(),lr=0.1)
num_epochs=10
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,trainer)
plt.show()
|