IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 游戏开发 -> ValueError: Please provide model inputs as a list or tuple of 2 or 3 elements: (input target) -> 正文阅读

[游戏开发]ValueError: Please provide model inputs as a list or tuple of 2 or 3 elements: (input target)

ValueError: Please provide model inputs as a list or tuple of 2 or 3 elements: (input, target)

报错信息

Traceback (most recent call last):  
  File "vae.py", line 170, in <module>  
    train_model(vae)  
  File "vae.py", line 161, in train_model  
    vae.fit(sequence, epochs=epochs)  
  File "/home/fanjiarong/.local/lib/python2.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 819, in fit
    use_multiprocessing=use_multiprocessing)
  File "/home/fanjiarong/.local/lib/python2.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 235, in fit
    use_multiprocessing=use_multiprocessing)
  File "/home/fanjiarong/.local/lib/python2.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 593, in _process_training_inputs
    use_multiprocessing=use_multiprocessing)
  File "/home/fanjiarong/.local/lib/python2.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 706, in _process_inputs
    use_multiprocessing=use_multiprocessing)
  File "/home/fanjiarong/.local/lib/python2.7/site-packages/tensorflow_core/python/keras/engine/data_adapter.py", line 952, in __init__
    **kwargs)
  File "/home/fanjiarong/.local/lib/python2.7/site-packages/tensorflow_core/python/keras/engine/data_adapter.py", line 767, in __init__
    dataset = standardize_function(dataset)
  File "/home/fanjiarong/.local/lib/python2.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 660, in standardize_function
    standardize(dataset, extract_tensors_from_dataset=False)
  File "/home/fanjiarong/.local/lib/python2.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 2346, in _standardize_user_data
    all_inputs, y_input, dict_inputs = self._build_model_with_inputs(x, y)
  File "/home/fanjiarong/.local/lib/python2.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 2523, in _build_model_with_inputs
    inputs, targets, _ = training_utils.extract_tensors_from_dataset(inputs)
  File "/home/fanjiarong/.local/lib/python2.7/site-packages/tensorflow_core/python/keras/engine/training_utils.py", line 1678, in extract_tensors_from_dataset
    inputs, targets, sample_weight = unpack_iterator_input(iterator)
  File "/home/fanjiarong/.local/lib/python2.7/site-packages/tensorflow_core/python/keras/engine/training_utils.py", line 1703, in unpack_iterator_input
    'Received %s' % next_element)

ValueError: Please provide model inputs as a list or tuple of 2 or 3 elements: (input, target) or (input, target, sample_weights)
Received tf.Tensor(
[[0.49803922 0.27058825 0.29803923 ... 0.         0.         0.        ]
 [0.49803922 0.27058825 0.29803923 ... 0.         0.         0.        ]
 [0.49803922 0.27058825 0.29803923 ... 0.         0.         0.        ]
 ...
 [0.49803922 0.27058825 0.29803923 ... 0.         0.         0.        ]
 [0.49803922 0.27058825 0.29803923 ... 0.         0.         0.        ]
 [0.49803922 0.27058825 0.29803923 ... 0.         0.         0.        ]], shape=(128, 7744), dtype=float32)   
2022-04-01 21:25:41.965783: W tensorflow/core/kernels/data/generator_dataset_op.cc:103] Error occurred when finalizing GeneratorDataset iterator: Cancelled: Operation was cancelled

分析程序

初始程序如下所示,在实现AE训练中的fit()函数报错,原因为函数接受的参数为向量化的种子文件(input),希望的参数为(input, target) or (input, target, sample_weights) 。我们需要明确对于分类问题,target应为input对应的类别,训练数据中包含数据及类别,可以直接读取;而对于自编码器target应为input对应的重构,需要将input输入AE网络中进行重构,此处使用fit()函数存在不妥。

class TrainSequence(keras.utils.Sequence):
    def __init__(self, input_file_dir, batch_size, configs = configs):
        pass
        
    def __len__(self):
        pass

    def __getitem__(self, idx):
        # 取一批
        batch_names = self.input_files[
            idx * self.batch_size: (idx + 1) * self.batch_size
        ]
        X_length = self.inpurt_sqrt_dim * self.inpurt_sqrt_dim
        a = vectorize_file(batch_names[0], X_length)
        b = vectorize_file(batch_names[1], X_length)
        test = np.array([a,b])
        # 读取并转为numpy矩阵,
        return np.array([vectorize_file(file_name, X_length) for file_name in batch_names])

sequence = TrainSequence(input_dir_path, batch_size, configs)
# 利用生成器提供数据
vae.fit(sequence, epochs=epochs)

修改程序

通过对自编码器的原理及训练过程进行学习,我们使用新的训练方式,通过GradientTape从头开始写训练/评估的循环,修改之后的代码如下所示。
tensorflow2.0 Seq2Seq多个输入时在model.fit()中使用生成器分批训练大量数据
自编码器(AutoEncoder)入门及TensorFlow实现
Tensorflow2 自定义训练

for epoch in range(epochs):
    print('epoch: ', epoch)
    for step, data in enumerate(sequence):
    	with tf.GradientTape() as tape:
         z_mean, z_log_var, z = self.encoder(data)
         reconstruction = self.decoder(z)
         # 重构损失函数
         # axis = -1指按最后一个dimension
         reconstruction_loss = tf.reduce_mean(keras.losses.binary_crossentropy(data, reconstruction))
         # 正则化约束损失函数
         kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
         # 求和
         kl_loss = tf.reduce_mean(kl_loss)
         total_loss = reconstruction_loss + kl_loss
     grads = tape.gradient(total_loss, self.trainable_weights)
     self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
     self.total_loss_tracker.update_state(total_loss)
     self.reconstruction_loss_tracker.update_state(reconstruction_loss)
     self.kl_loss_tracker.update_state(kl_loss)
vae.save_weights("vae_model_weight")

对于自编码器使用GradientTape从头开始写训练的循环可以避免上述错误。对于训练循环来说,主要包括下面几个部分:

  1. 用一个for循环来控制训练的轮次
  2. 在每轮训练过程中,用一个for循环来控制训练的批次
  3. 在每个批次中,构建一个GradientTape()
  4. 在这个域中,我们调用模型的前向传播,并计算loss
  5. 在域之外,我们计算loss对模型的参数的梯度
  6. 根据梯度,使用优化器更新模型的权重
  游戏开发 最新文章
6、英飞凌-AURIX-TC3XX: PWM实验之使用 GT
泛型自动装箱
CubeMax添加Rtthread操作系统 组件STM32F10
python多线程编程:如何优雅地关闭线程
数据类型隐式转换导致的阻塞
WebAPi实现多文件上传,并附带参数
from origin ‘null‘ has been blocked by
UE4 蓝图调用C++函数(附带项目工程)
Unity学习笔记(一)结构体的简单理解与应用
【Memory As a Programming Concept in C a
上一篇文章      下一篇文章      查看所有文章
加:2022-04-06 16:21:40  更:2022-04-06 16:23:17 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/16 19:59:53-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码