IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 基于深度强化学习的绘画智能体 代码分析(二) -> 正文阅读

[人工智能]基于深度强化学习的绘画智能体 代码分析(二)

基于深度强化学习的绘画智能体 代码分析(二)

Github源码链接

critic.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.weight_norm as weightNorm

from torch.autograd import Variable //Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现
import sys
def conv3x3(in_planes, out_planes, stride=1):  //定义3*3的二维卷积模板
    return weightNorm(nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") //这个device的用处是作为Tensor或者Model被分配到的位置。
coord = torch.zeros([1, 2, 64, 64]) //返回一个形状为[1, 2, 64, 64]的矩阵,里面的每一个值都是0的tensor
for i in range(64):
    for j in range(64):
        coord[0, 0, i, j] = i / 63. //目的是将矩阵的维度长度平均到1??
        coord[0, 1, i, j] = j / 63.
        coord = coord.to(device) //这行代码的意思是将所有最开始读取数据时的tensor变量copy一份到device所指定的GPU上去,之后的运算都在GPU上进行。      
class TReLU(nn.Module):
    def __init__(self)://初始化
        super(TReLU, self).__init__()//就是对继承自父类nn.Module的属性进行初始化。而且是用nn.Module的初始化方法来初始化继承的属性。
        self.alpha = nn.Parameter(torch.FloatTensor(1), requires_grad=True) 
        
//torch.nn.Parameter()函数:含义是将一个固定不可训练的tensor转换成可以训练的类型parameter,并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.v变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。
//比如cnn输出4个东西,你又不想concate到到一起,你想用权重加法,权重又不想自己设定,想让网络自己学requires_grad=True这个很重要
        self.alpha.data.fill_(0)
def forward(self, x): //损失函数
        x = F.relu(x - self.alpha) + self.alpha
        return x
def cfg(depth):
    depth_lst = [18, 34, 50, 101, 152]
    assert (depth in depth_lst), "Error : Resnet depth should be either 18, 34, 50, 101, 152" //assert检查条件,不符合就终止程序,检测了Resnet 的层数?
    cf_dict = {
        '18': (BasicBlock, [2,2,2,2]),
        '34': (BasicBlock, [3,4,6,3]),
        '50': (Bottleneck, [3,4,6,3]), //Bottleneck是对于更深的网络,提出了另一种残差基础块
        '101':(Bottleneck, [3,4,23,3]),
        '152':(Bottleneck, [3,8,36,3]),
    }

    return cf_dict[str(depth)]
class BasicBlock(nn.Module): //定义BasicBlock内容
    expansion = 1 //expansion是BasicBlock和Bottleneck的核心区别之一
    

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes, stride) //两个3x3的卷积
        self.conv2 = conv3x3(planes, planes)

        self.shortcut = nn.Sequential()//输入和输出维度匹配的情况
        

//输入和输出维度不匹配的情况(需要借助conv+bn将输入尺寸降低)
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                weightNorm(nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True)),
            )
        self.relu_1 = TReLU()
        self.relu_2 = TReLU()

    def forward(self, x):
        out = self.relu_1(self.conv1(x))
        out = self.conv2(out)
        out += self.shortcut(x)
        out = self.relu_2(out)

        return out
class Bottleneck(nn.Module): //定义Bottleneck内容
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        
// weightNorm:对参数的规范化
        self.conv1 = weightNorm(nn.Conv2d(in_planes, planes, kernel_size=1, bias=True))
        self.conv2 = weightNorm(nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True))
        self.conv3 = weightNorm(nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=True))
        self.relu_1 = TReLU()
        self.relu_2 = TReLU()
        self.relu_3 = TReLU()

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                weightNorm(nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True)),
            )

    def forward(self, x):
        out = self.relu_1(self.conv1(x))
        out = self.relu_2(self.conv2(out))
        out = self.conv3(out)
        out += self.shortcut(x)
        out = self.relu_3(out)

        return out
class ResNet_wobn(nn.Module): //定义整个残差网络
    def __init__(self, num_inputs, depth, num_outputs):
        super(ResNet_wobn, self).__init__()
        self.in_planes = 64

        block, num_blocks = cfg(depth)
        self.conv0 = conv3x3(num_inputs, 32, 2) // 64        
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=2) // 32
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) // 16
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=1)
        self.conv4 = weightNorm(nn.Conv2d(512, 1, 1, 1, 0))
        self.relu_1 = TReLU()
        self.conv1 = weightNorm(nn.Conv2d(65 + 2, 64, 1, 1, 0))        
        self.conv2 = weightNorm(nn.Conv2d(64, 64, 1, 1, 0))
        self.conv3 = weightNorm(nn.Conv2d(64, 32, 1, 1, 0))
        self.relu_2 = TReLU()
        self.relu_3 = TReLU()
        self.relu_4 = TReLU()
        

//_make_layer方法作用是生成多个卷积层,形成一个大的模块。
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []

        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def a2img(self, x):
        tmp = coord.expand(x.shape[0], 2, 64, 64) //coord_ 系列函数可以改变xy轴的位置,x.shape[0]读取矩阵第一维度的长度
        x = x.repeat(64, 64, 1, 1).permute(2, 3, 0, 1) //permute函数可以对任意高维矩阵进行转置
        x = self.relu_2(self.conv1(torch.cat([x, tmp], 1))) //torch.cat()是为了把函数torch.stack()得到tensor进行拼接(concatnate)而存在的。函数目的: 在给定维度上对输入的张量序列seq 进行连接操作。
        x = self.relu_3(self.conv2(x))
        x = self.relu_4(self.conv3(x))
        return x
        
    def forward(self, input):
        x, a = input
        a = self.a2img(a)
        x = self.relu_1(self.conv0(x))
        x = torch.cat([x, a], 1)        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.conv4(x)
        return x.view(x.size(0), 64)
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-02 20:48:53  更:2021-08-02 20:49:05 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/22 15:19:02-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码