1. 代码
import torch
from torch import nn
from d2l import torch as d2l
# Two-layer MLP classifier: flatten each input, one hidden layer of 256 ReLU
# units, then 10 raw output scores (logits) fed to CrossEntropyLoss.
# 784 = 28 * 28, i.e. one flattened single-channel 28x28 image, matching the
# Fashion-MNIST data this script trains on; other input sizes need a
# different first Linear layer.
net = nn.Sequential(
    nn.Flatten(),         # (batch, ...) -> (batch, features)
    nn.Linear(784, 256),  # input features -> hidden units
    nn.ReLU(),            # hidden nonlinearity
    nn.Linear(256, 10),   # hidden units -> one score per class
)
def init_weights(m):
if type(m) == nn.linear:
nn.init.normal_(m.weight,std=0.01)
# Module.apply calls init_weights on every submodule of net (and net itself);
# init_weights only acts on the nn.Linear layers.
net.apply(init_weights)

batch_size, lr, num_epochs = 256, 0.1, 10
# reduction='none' keeps one loss per example.
# NOTE(review): this assumes d2l's train_ch3 as used in the book's concise MLP
# example (d2l 0.17.x), which calls l.mean().backward() for torch optimizers
# and sums l for its reported loss. A pre-averaged scalar would make the
# reported training loss batch_size times too small. Gradients are the same
# as with the default reduction='mean'. Confirm against the installed d2l.
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=lr)

# Fashion-MNIST train/test iterators with the chosen minibatch size.
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
2. 结果
3. 模块备注
3.1 nn.Flatten()
- 作用
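将输入张量从第 1 维开始展平（默认 start_dim=1, end_dim=-1），保留第 0 维的批量维度：例如形状为 (batch, 1, 28, 28) 的图像批量会变为 (batch, 784)，即每个样本被展平为长度为 N 的一维向量。一般放在网络最前面，把图像转换成全连接层可以接收的输入。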
3.2 nn.Linear()
3.3 nn.ReLU()
3.4 nn.init.normal_()
|