IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【MindSpore产品】【数据处理功能】加入数据增强之后,报出卷积输入类型不同的问题 -> 正文阅读

[人工智能]【MindSpore产品】【数据处理功能】加入数据增强之后,报出卷积输入类型不同的问题

【功能模块】

# 图像增强
trans = [
    transforms.RandomCrop((32, 32), (4, 4, 4, 4), fill_value=(255,255,255)), # 对图像进行自动裁剪
    transforms.RandomHorizontalFlip(prob=0.5), # 对图像进行随机水平翻转
    transforms.RandomRotation(degrees=20, fill_value=(255,255,255)),
    # transforms.HWC2CHW(), # (h, w, c)转换为(c, h, w)
]
# 下载解压并加载CIFAR-10训练数据集
dataset_train = Cifar10(path=data_dir, split='train', batch_size=6, shuffle=True, resize=32, download=True, transform=trans)
ds_train = dataset_train.run()
model.train(num_epochs, ds_train, callbacks=[ValAccMonitor(model, ds_val, num_epochs)])

【操作步骤&问题现象】

Traceback (most recent call last):
? File "F:/8.Learning Task/MindSpore/ResNet/train.py", line 49, in <module>
? ? model.train(num_epochs, ds_train, callbacks=[ValAccMonitor(model, ds_val, num_epochs)])
? File "D:\Anaconda1\lib\site-packages\mindspore\train\model.py", line 906, in train
? ? sink_size=sink_size)
? File "D:\Anaconda1\lib\site-packages\mindspore\train\model.py", line 87, in wrapper
? ? func(self, *args, **kwargs)
? File "D:\Anaconda1\lib\site-packages\mindspore\train\model.py", line 546, in _train
? ? self._train_process(epoch, train_dataset, list_callback, cb_params)
? File "D:\Anaconda1\lib\site-packages\mindspore\train\model.py", line 794, in _train_process
? ? outputs = self._train_network(*next_element)
? File "D:\Anaconda1\lib\site-packages\mindspore\nn\cell.py", line 586, in __call__
? ? out = self.compile_and_run(*args)
? File "D:\Anaconda1\lib\site-packages\mindspore\nn\cell.py", line 964, in compile_and_run
? ? self.compile(*inputs)
? File "D:\Anaconda1\lib\site-packages\mindspore\nn\cell.py", line 937, in compile
? ? _cell_graph_executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode)
? File "D:\Anaconda1\lib\site-packages\mindspore\common\api.py", line 1006, in compile
? ? result = self._graph_executor.compile(obj, args_list, phase, self._use_vm_mode())
TypeError: mindspore\core\utils\check_convert_utils.cc:701 _CheckTypeSame] For primitive[Conv2D], the input type must be same.
name:[w]:Ref[Tensor(F32)].
name:[x]:Tensor[UInt8].

WARNING: Logging before InitGoogleLogging() is written to STDERR
[CRITICAL] CORE(22848,1,?):2022-6-6 12:59:53 [mindspore\core\utils\check_convert_utils.cc:701] _CheckTypeSame] For primitive[Conv2D], the input type must be same.
name:[w]:Ref[Tensor(F32)].
name:[x]:Tensor[UInt8].

【日志信息】(可选,上传日志内容或者附件)

不知该如何让input的类型相同,求大佬们能看看,给个办法,谢谢!!!

总体代码如下:

# train.py

from mindvision.dataset import Cifar10
import mindspore.dataset.vision.c_transforms as transforms

# 数据集根目录
data_dir = "./datasets"
# 图像增强
# 图像增强
trans = [
    transforms.RandomCrop((32, 32), (4, 4, 4, 4), fill_value=(255,255,255)), # 对图像进行自动裁剪
    transforms.RandomHorizontalFlip(prob=0.5), # 对图像进行随机水平翻转
    transforms.RandomRotation(degrees=20, fill_value=(255,255,255)),
    # transforms.HWC2CHW(), # (h, w, c)转换为(c, h, w)
]

# 下载解压并加载CIFAR-10训练数据集
dataset_train = Cifar10(path=data_dir, split='train', batch_size=6, shuffle=True, resize=32, download=True, transform=trans)
ds_train = dataset_train.run()
step_size = ds_train.get_dataset_size()
# 下载解压并加载CIFAR-10测试数据集
dataset_val = Cifar10(path=data_dir, split='test', batch_size=6, resize=32, download=True)
ds_val = dataset_val.run()


from mindspore.train import Model
from mindvision.engine.callback import ValAccMonitor
from mindvision.classification.models.head import DenseHead
from mindspore import nn
from ResNet.resnet import resnet50

# 定义ResNet50网络
network = resnet50(pretrained=True)

# 全连接层输入层的大小
in_channel = network.head.dense.in_channels
head = DenseHead(input_channel=in_channel, num_classes=10)
# 重置全连接层
network.head = head
# 设置学习率
num_epochs = 40
lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size * num_epochs,
                        step_per_epoch=step_size, decay_epoch=num_epochs)
# 定义优化器和损失函数
opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# 实例化模型
model = Model(network, loss, opt, metrics={"Accuracy": nn.Accuracy()})
# 模型训练
model.train(num_epochs, ds_train, callbacks=[ValAccMonitor(model, ds_val, num_epochs)])

?从报错来看出问题代码应该在你自定义网络resnet里,所以我在本地尝试未复现成功;然后从这个报错来看应该是你的卷积算子CONV2D输入不一致,有试过在CONV2D算子前将所有的输入转为float32格式然后继续呢,如果不行的话麻烦您再提供一下自定义resnet网络的脚本。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-10-31 11:56:46  更:2022-10-31 12:01:52 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/28 3:39:04-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计