[Andrew Ng Deep Learning] Convolutional Neural Networks: Application (PyTorch)

Original article: the TensorFlow version.

Import the required packages

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch import optim
from cnn_utils import *   # course helpers: load_dataset, random_mini_batches, ...

The Flatten class

Early versions of PyTorch did not provide an nn.Flatten class, so we write one by hand here:

class Flatten(nn.Module):
    """Flattens dimensions start_dim..end_dim of its input, mirroring nn.Flatten."""

    def __init__(self, start_dim=1, end_dim=-1):
        super(Flatten, self).__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input):
        return input.flatten(self.start_dim, self.end_dim)
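
A quick sanity check (not part of the original post) that the hand-written class behaves like torch.flatten with start_dim=1:

# Illustrative check: a (batch, C, H, W) tensor becomes (batch, C*H*W),
# keeping the batch dimension intact.
x = torch.randn(2, 16, 2, 2)
assert Flatten()(x).shape == (2, 64)
assert torch.equal(Flatten()(x), x.flatten(1))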

Building the network

Network structure: CONV2D -> RELU -> MAXPOOL -> CONV2D -> RELU -> MAXPOOL -> FLATTEN -> FULLYCONNECTED
Network details:
In detail, we will use the following parameters for all the steps:
- Conv2D: stride 1, padding is “SAME”
- ReLU
- Max pool: Use an 8 by 8 filter size and an 8 by 8 stride, padding is “SAME”
- Conv2D: stride 1, padding is “SAME”
- ReLU
- Max pool: Use a 4 by 4 filter size and a 4 by 4 stride, padding is “SAME”
- Flatten the previous output.
- FULLYCONNECTED (FC) layer: Apply a fully connected layer without a non-linear activation function. Do not call the softmax here. This will result in 6 neurons in the output layer, which then get passed later to a softmax.

Early versions of PyTorch do not support padding="SAME", and this network needs asymmetric padding (one extra pixel on the right/bottom), so explicit nn.ZeroPad2d layers are inserted by hand; the implementation is shown below.
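
Before the full network, a small sketch (not in the original post) of how the ZeroPad2d amounts are derived: with stride 1, "SAME" padding needs kernel_size - 1 extra pixels in total per spatial dimension, and TensorFlow puts the extra pixel on the right/bottom when that total is odd.

# Sketch: derive TF-style "SAME" padding amounts for a stride-1 convolution.
def same_padding(kernel_size):
    total = kernel_size - 1            # extra pixels so that output size == input size
    left = total // 2
    right = total - left               # the odd pixel goes to the right/bottom, as in TF
    return [left, right, left, right]  # nn.ZeroPad2d order: (left, right, top, bottom)

print(same_padding(4))  # [1, 2, 1, 2] -> padding before the first conv
print(same_padding(2))  # [0, 1, 0, 1] -> padding before the second conv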

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.ZeroPad2d([1, 2, 1, 2]),                                          # 64x64x3 input -> 67x67x3
            nn.Conv2d(in_channels=3, out_channels=8, kernel_size=4, stride=1),   # -> 64x64x8
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=8, stride=8),                               # -> 8x8x8
            nn.ZeroPad2d([0, 1, 0, 1]),                                          # -> 9x9x8
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=2, stride=1),  # -> 8x8x16
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4, stride=4),                               # -> 2x2x16
            Flatten(),                                                           # -> 64
            nn.Linear(64, 6)                                                     # -> 6 class logits
        )
        self._init_parameters()

    def forward(self, x):
        x = self.net(x)
        return x

    def _init_parameters(self):
        # Xavier/Glorot initialization for the convolution weights.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight.data)
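
As a quick shape check (illustrative, not in the original post), a batch of 64x64 RGB images should come out as 6 logits per example; the 64 in nn.Linear(64, 6) is the 16 x 2 x 2 feature map left after the second max pool.

net = Net()
dummy = torch.randn(4, 3, 64, 64)   # (batch, channels, height, width)
print(net(dummy).shape)             # torch.Size([4, 6])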

Model

def model(X_train, Y_train, X_test, Y_test, learning_rate=0.009,
          num_epochs=100, minibatch_size=64, print_cost=True):

    seed = 3  # to keep results consistent (numpy seed)
    (m, n_C0, n_H0, n_W0) = X_train.shape
    n_y = Y_train.shape[1]
    costs = []  # To keep track of the cost

    net = Net()
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    net.train()
    # Do the training loop
    for epoch in range(num_epochs):

        minibatch_cost = 0.
        num_minibatches = int(m / minibatch_size)  # number of minibatches of size minibatch_size in the train set
        seed = seed + 1
        minibatches = random_mini_batches(X_train, Y_train, minibatch_size, seed)

        for minibatch in minibatches:
            # Select a minibatch
            (minibatch_X, minibatch_Y) = minibatch
            minibatch_X = torch.tensor(minibatch_X).float()
            minibatch_Y = torch.tensor(minibatch_Y).long().squeeze()

            optimizer.zero_grad()
            output = net(minibatch_X)
            temp_cost = criterion(output, minibatch_Y)
            temp_cost.backward()
            optimizer.step()

            minibatch_cost += temp_cost.item() / num_minibatches

        # Print the cost every epoch
        if print_cost and epoch % 5 == 0:
            print("Cost after epoch %i: %f" % (epoch, minibatch_cost))
        if print_cost:
            costs.append(minibatch_cost)

    # plot the cost
    plt.plot(np.squeeze(costs))
    plt.ylabel('cost')
    plt.xlabel('epochs')
    plt.title("Learning rate =" + str(learning_rate))
    plt.show()

    net.eval()
    with torch.no_grad():
        X = torch.tensor(X_train).float()
        Y = torch.tensor(Y_train).long().squeeze()
        output = net(X)
        output = torch.argmax(output, dim=1)
        correct_prediction = output == Y
        train_accuracy = torch.sum(correct_prediction).float() / X.shape[0]
        train_accuracy = train_accuracy.item()
        print("Train Accuracy:", train_accuracy)

        X = torch.tensor(X_test).float()
        Y = torch.tensor(Y_test).long().squeeze()
        output = net(X)
        output = torch.argmax(output, dim=1)
        correct_prediction = output == Y
        test_accuracy = torch.sum(correct_prediction).float() / X.shape[0]
        test_accuracy = test_accuracy.item()
        print("Test Accuracy:", test_accuracy)

    return train_accuracy, test_accuracy
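
Note that nn.CrossEntropyLoss applies log-softmax internally and expects integer class indices rather than one-hot labels, which is why the labels above are only cast with .long().squeeze() instead of being one-hot encoded. A tiny illustration (not in the original post):

logits = torch.tensor([[2.0, 0.5, -1.0, 0.0, 0.3, 0.1]])  # raw scores for one example, 6 classes
label = torch.tensor([0])                                  # class index, not a one-hot vector
print(nn.CrossEntropyLoss()(logits, label))                # small loss, since class 0 has the largest logit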

Main function: loads the data and preprocesses it


if __name__ == '__main__':

    # Loading the data (signs)
    X_train_orig, Y_train_orig, X_test_orig, Y_test_orig, classes = load_dataset()

    X_train = X_train_orig / 255.                    # scale pixel values to [0, 1]
    X_test = X_test_orig / 255.
    Y_train = Y_train_orig.T                         # labels as a column of class indices
    Y_test = Y_test_orig.T
    X_train = np.transpose(X_train, [0, 3, 1, 2])    # NHWC -> NCHW, the layout PyTorch expects
    X_test = np.transpose(X_test, [0, 3, 1, 2])

    print("number of training examples = " + str(X_train.shape[0]))
    print("number of test examples = " + str(X_test.shape[0]))
    print("X_train shape: " + str(X_train.shape))
    print("Y_train shape: " + str(Y_train.shape))
    print("X_test shape: " + str(X_test.shape))
    print("Y_test shape: " + str(Y_test.shape))

    model(X_train, Y_train, X_test, Y_test)
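
The helpers load_dataset and random_mini_batches come from the course's cnn_utils module, which is not reproduced here. As a rough sketch of the interface the code above assumes (the actual course implementation may differ in details), random_mini_batches shuffles the examples with the given seed and splits them into batches of minibatch_size:

# Rough sketch of the assumed random_mini_batches behaviour; the real helper
# lives in cnn_utils and may differ in details.
def random_mini_batches(X, Y, mini_batch_size=64, seed=0):
    np.random.seed(seed)
    m = X.shape[0]
    permutation = np.random.permutation(m)
    X_shuffled, Y_shuffled = X[permutation], Y[permutation]
    return [(X_shuffled[k:k + mini_batch_size], Y_shuffled[k:k + mini_batch_size])
            for k in range(0, m, mini_batch_size)]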
