import torch
from matplotlib import pyplot as plt
import numpy as np
from torch.utils import data
import torch
from torch import nn
from collections.abc import Iterator,Iterable
# Ground-truth parameters used to synthesize the regression dataset.
true_w = torch.tensor([7.0, 8.0])
true_b = torch.tensor([5.0])


def generate_data(num_examples, w=None, b=None, noise_std=0.01):
    """Sample a synthetic linear-regression dataset ``y = x @ w + b + noise``.

    Args:
        num_examples: number of rows to draw.
        w: 1-D weight tensor; its length sets the number of features.
            Defaults to the module-level ``true_w``.
        b: bias tensor broadcastable to ``(num_examples,)``.
            Defaults to the module-level ``true_b``.
        noise_std: standard deviation of the additive Gaussian label noise.

    Returns:
        ``(x, y)`` where ``x`` has shape ``(num_examples, len(w))`` with
        standard-normal entries and ``y`` has shape ``(num_examples, 1)``.
    """
    if w is None:
        w = true_w
    if b is None:
        b = true_b
    x = torch.normal(0, 1, [num_examples, len(w)])
    y = torch.matmul(x, w) + b + torch.normal(0, noise_std, [num_examples])
    # Return targets as a column vector so they line up with the
    # (batch, 1) output of an nn.Linear(..., 1) model under MSELoss.
    return x, y.reshape(-1, 1)
def load_data(data_array, batch_size, is_train=True):
    """Wrap in-memory tensors in a PyTorch ``DataLoader``.

    Args:
        data_array: sequence of tensors with equal first dimension
            (e.g. ``(features, labels)``); unpacked into ``TensorDataset``.
        batch_size: number of examples per minibatch.
        is_train: when true the loader reshuffles the data every epoch;
            otherwise examples are yielded in their stored order.

    Returns:
        A ``torch.utils.data.DataLoader`` over the given tensors.
    """
    dataset = data.TensorDataset(*data_array)
    return data.DataLoader(dataset, batch_size=batch_size, shuffle=is_train)
# Mean-squared-error objective (nn.MSELoss averages over the batch by default).
loss = nn.MSELoss()
# One fully connected layer mapping 2 input features to 1 output.
net = nn.Sequential(nn.Linear(2, 1))
trainer = torch.optim.SGD(net.parameters(), lr=0.03)

# Keep the full dataset around so the loss can be reported after each epoch.
features, labels = generate_data(1000)
data_iter = load_data((features, labels), 5)

num_epochs = 1  # number of full passes over the dataset
for epoch in range(num_epochs):
    for x, y in data_iter:
        batch_loss = loss(net(x), y)
        batch_loss.backward()
        trainer.step()
        # Gradients accumulate in PyTorch, so clear them before the next batch.
        trainer.zero_grad()
    # Evaluate on the whole dataset without building an autograd graph.
    with torch.no_grad():
        epoch_loss = loss(net(features), labels)
    print(f"epoch {epoch + 1}, loss {epoch_loss.item():f}")
print(net.state_dict())