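1. 原理
线性回归的输出 $wx + b$ 可以是任意实数，无法直接表示"样本属于某一类"的概率。因此在线性模型的输出上加一个 Sigmoid 激活函数，把结果压缩到 $(0, 1)$ 区间，作为样本属于正类（标签为 1）的概率：$\hat{y} = \sigma(wx + b) = \frac{1}{1 + e^{-(wx + b)}}$。这样线性回归就变成了可用于二分类的逻辑回归。相应地，损失函数由均方误差换成二元交叉熵（BCE）：$\text{loss} = -\left[\, y \log \hat{y} + (1 - y) \log (1 - \hat{y}) \,\right]$。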
2. 代码实现
import torch
# import torch.nn.functional as F
# prepare dataset: one feature per sample, so both tensors have shape (4, 1).
# Labels are the 0/1 class targets, written as floats because BCELoss
# compares them against float probabilities of the same dtype.
x_data = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y_data = torch.tensor([[0.0], [0.0], [1.0], [1.0]])
# design model using class
class LogisticRegressionModel(torch.nn.Module):
    """Logistic regression: a 1-input/1-output linear layer followed by a sigmoid.

    The sigmoid squashes the linear output ``w * x + b`` into the interval
    (0, 1), so the result can be read as the predicted probability that a
    sample belongs to the positive class (label 1).
    """

    def __init__(self):
        super().__init__()
        # one input feature -> one output logit (holds the scalar weight w and bias b)
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        """Return ``sigmoid(linear(x))``.

        Args:
            x: float tensor whose last dimension is 1 (one feature per sample),
               e.g. shape (batch, 1).

        Returns:
            Tensor with the same shape as ``x``, holding the predicted
            probabilities produced by the sigmoid.
        """
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred
model = LogisticRegressionModel()
# Binary cross-entropy: it expects probabilities (the model already applies
# sigmoid) and float targets of the same shape as the predictions.
criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# training cycle: forward, backward, update
for epoch in range(1000):
    y_pred = model(x_data)            # forward: predicted probabilities
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())

    optimizer.zero_grad()             # clear gradients left over from the previous step
    loss.backward()                   # backward: compute d(loss)/d(w) and d(loss)/d(b)
    optimizer.step()                  # update: plain SGD step, param -= lr * grad

# learned parameters (Linear(1, 1) has a single weight and a single bias)
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

# Inference on an unseen input. no_grad() skips building the autograd graph,
# replacing the discouraged `.data` access used for printing.
x_test = torch.tensor([[2.2]])
with torch.no_grad():
    y_test = model(x_test)
print('y_test= ', y_test)
3. 运行结果
0 0.8333730697631836
1 0.8216124773025513
2 0.8102343082427979
3 0.799232006072998
······（第 4 至 996 轮的输出省略）······
997 0.3926210105419159
998 0.3925216794013977
999 0.3924223482608795
w = 0.8659290075302124
b = -1.701526165008545
y_test= tensor([[0.7102]])
|