IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> 使用Numpy实现卷积的前向传播 -> 正文阅读

[Python知识库]使用Numpy实现卷积的前向传播

import numpy as np
import math


class Conv2D(object):
    def __init__(self, shape, output_channels, ksize=3, stride=1, method='VALID'):
        self.input_shape = shape
        self.output_channels = output_channels
        self.input_channels = shape[-1]
        self.batchsize = shape[0]
        self.stride = stride
        self.ksize = ksize
        self.method = method
        weights_scale = math.sqrt(ksize * ksize * self.input_channels / 2)
        self.weights = np.random.standard_normal(
            (ksize, ksize, self.input_channels, self.output_channels)) // weights_scale
        self.bias = np.random.standard_normal(self.output_channels) // weights_scale
        if method == 'VALID':
            self.eta = np.zeros((shape[0], (shape[1] - ksize) // self.stride + 1, (shape[1] - ksize) // self.stride + 1,
                                 self.output_channels))
        if method == 'SAME':
            self.eta = np.zeros((shape[0], shape[1] // self.stride, shape[2] // self.stride, self.output_channels))
        self.w_gradient = np.zeros(self.weights.shape)
        self.b_gradient = np.zeros(self.bias.shape)
        self.output_shape = self.eta.shape

    def forward(self, x):
        col_weights = self.weights.reshape([-1, self.output_channels])
        if self.method == 'SAME':
            x = np.pad(x, ((0, 0), (self.ksize // 2, self.ksize // 2), (self.ksize // 2, self.ksize // 2), (0, 0)),
                       'constant', constant_values=0)
        self.col_image = []
        conv_out = np.zeros(self.eta.shape)
        for i in range(self.batchsize):
            img_i = x[i][np.newaxis, ...]
            self.col_image_i = self.im2col(img_i, self.ksize, self.stride)
            print(col_weights.shape)
            conv_out[i] = np.reshape(np.dot(self.col_image_i, col_weights) + self.bias, self.eta[0].shape)
            self.col_image.append(self.col_image_i)
        return conv_out 

    def im2col(self, image, k_size, stride):
        image_col = []
        for i in range(0, image.shape[1] - k_size + 1, stride):
            for j in range(0, image.shape[2] - k_size + 1,
                           stride): 
                col = image[:, i:i + k_size, j:j + k_size, :].reshape([-1])
                image_col.append(col)
        image_col = np.array(image_col)
        print(image_col.shape)
        return image_col


if __name__ == '__main__':
    conv2d = Conv2D([5, 10, 10, 3], 32, 3, 1, 'VALID')
    input_data = np.random.standard_normal((5, 10, 10, 3))
    print("input:", input_data.shape)
    conv_out = conv2d.forward(input_data)
    print(conv_out.shape)

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2021-08-11 12:18:43  更:2021-08-11 12:21:45 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年5日历 -2024/5/17 12:55:46-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码