【功能模块】
# 图像增强
trans = [
transforms.RandomCrop((32, 32), (4, 4, 4, 4), fill_value=(255,255,255)), # 对图像进行自动裁剪
transforms.RandomHorizontalFlip(prob=0.5), # 对图像进行随机水平翻转
transforms.RandomRotation(degrees=20, fill_value=(255,255,255)),
# transforms.HWC2CHW(), # (h, w, c)转换为(c, h, w)
]
# 下载解压并加载CIFAR-10训练数据集
dataset_train = Cifar10(path=data_dir, split='train', batch_size=6, shuffle=True, resize=32, download=True, transform=trans)
ds_train = dataset_train.run()
model.train(num_epochs, ds_train, callbacks=[ValAccMonitor(model, ds_val, num_epochs)])
【操作步骤&问题现象】
Traceback (most recent call last): ? File "F:/8.Learning Task/MindSpore/ResNet/train.py", line 49, in <module> ? ? model.train(num_epochs, ds_train, callbacks=[ValAccMonitor(model, ds_val, num_epochs)]) ? File "D:\Anaconda1\lib\site-packages\mindspore\train\model.py", line 906, in train ? ? sink_size=sink_size) ? File "D:\Anaconda1\lib\site-packages\mindspore\train\model.py", line 87, in wrapper ? ? func(self, *args, **kwargs) ? File "D:\Anaconda1\lib\site-packages\mindspore\train\model.py", line 546, in _train ? ? self._train_process(epoch, train_dataset, list_callback, cb_params) ? File "D:\Anaconda1\lib\site-packages\mindspore\train\model.py", line 794, in _train_process ? ? outputs = self._train_network(*next_element) ? File "D:\Anaconda1\lib\site-packages\mindspore\nn\cell.py", line 586, in __call__ ? ? out = self.compile_and_run(*args) ? File "D:\Anaconda1\lib\site-packages\mindspore\nn\cell.py", line 964, in compile_and_run ? ? self.compile(*inputs) ? File "D:\Anaconda1\lib\site-packages\mindspore\nn\cell.py", line 937, in compile ? ? _cell_graph_executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode) ? File "D:\Anaconda1\lib\site-packages\mindspore\common\api.py", line 1006, in compile ? ? result = self._graph_executor.compile(obj, args_list, phase, self._use_vm_mode()) TypeError: mindspore\core\utils\check_convert_utils.cc:701 _CheckTypeSame] For primitive[Conv2D], the input type must be same. name:[w]:Ref[Tensor(F32)]. name:[x]:Tensor[UInt8].
WARNING: Logging before InitGoogleLogging() is written to STDERR [CRITICAL] CORE(22848,1,?):2022-6-6 12:59:53 [mindspore\core\utils\check_convert_utils.cc:701] _CheckTypeSame] For primitive[Conv2D], the input type must be same. name:[w]:Ref[Tensor(F32)]. name:[x]:Tensor[UInt8].
【日志信息】(可选,上传日志内容或者附件)
不知该如何让input的类型相同,求大佬们能看看,给个办法,谢谢!!!
总体代码如下:
# train.py
from mindvision.dataset import Cifar10
import mindspore.dataset.vision.c_transforms as transforms
# 数据集根目录
data_dir = "./datasets"
# 图像增强
# 图像增强
trans = [
transforms.RandomCrop((32, 32), (4, 4, 4, 4), fill_value=(255,255,255)), # 对图像进行自动裁剪
transforms.RandomHorizontalFlip(prob=0.5), # 对图像进行随机水平翻转
transforms.RandomRotation(degrees=20, fill_value=(255,255,255)),
# transforms.HWC2CHW(), # (h, w, c)转换为(c, h, w)
]
# 下载解压并加载CIFAR-10训练数据集
dataset_train = Cifar10(path=data_dir, split='train', batch_size=6, shuffle=True, resize=32, download=True, transform=trans)
ds_train = dataset_train.run()
step_size = ds_train.get_dataset_size()
# 下载解压并加载CIFAR-10测试数据集
dataset_val = Cifar10(path=data_dir, split='test', batch_size=6, resize=32, download=True)
ds_val = dataset_val.run()
from mindspore.train import Model
from mindvision.engine.callback import ValAccMonitor
from mindvision.classification.models.head import DenseHead
from mindspore import nn
from ResNet.resnet import resnet50
# 定义ResNet50网络
network = resnet50(pretrained=True)
# 全连接层输入层的大小
in_channel = network.head.dense.in_channels
head = DenseHead(input_channel=in_channel, num_classes=10)
# 重置全连接层
network.head = head
# 设置学习率
num_epochs = 40
lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size * num_epochs,
step_per_epoch=step_size, decay_epoch=num_epochs)
# 定义优化器和损失函数
opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# 实例化模型
model = Model(network, loss, opt, metrics={"Accuracy": nn.Accuracy()})
# 模型训练
model.train(num_epochs, ds_train, callbacks=[ValAccMonitor(model, ds_val, num_epochs)])
?从报错来看出问题代码应该在你自定义网络resnet里,所以我在本地尝试未复现成功;然后从这个报错来看应该是你的卷积算子CONV2D输入不一致,有试过在CONV2D算子前将所有的输入转为float32格式然后继续呢,如果不行的话麻烦您再提供一下自定义resnet网络的脚本。
|