IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 记录若干`tf.py_function`的使用的方式,便于查阅 -> 正文阅读

[人工智能]记录若干`tf.py_function`的使用的方式,便于查阅

本文所使用的TensorFlow版本为 2.9.0-rc0
众所周知,在TensorFlow2.x中 tf.py_function可以帮助我们让原本只能在Eager Mode下才能运行的函数体顺利运行在计算图中,获得我们预期的输出。基本使用规则见tf.py_function,这里不再赘述。

由于底层的实现倾向于序列,tf.py_function对映射类型并不友好,即若输入类型为 dict[str,tf.Tensor], 是无法直接输出对应的字典形式的dict[str,tf.Tensor],这无疑增加了我们的coding难度。

1.处理输入输出structure一致的映射

简单起见,我们构建两个同功能的函数体original_funcmodified_func,后者支持tf.py_function,将输入的张量翻倍,不改变输入输出的映射结构。为了体现tf.py_function的特点与使用难点,我们以tf.string形式作为输入,则函数体内必须先将tf.string转化为一般的tf.in32或tf.float32。具体的代码如下:

import tensorflow as tf 
import functools
def tf_py_function_wrapper(func=None):
    # since tf.py_function can not deal with dict directly, and its using form is not easy
    # here we make this wrapper, it can trans `func`'s output 
    # all tensors that a user want to use by calling their `numpy()` functions should become the inputs of the wrapped `func`, otherwise, `numpy()` will not work
    # since tf.nest.flatten  tf.nest.pack_sequence_as will sort dict structure's `keys` automaticlly, we do not use tf.nest here to avoid unexpected behavior
    if func is None:
        return functools.partial(tf_py_function_wrapper,)
    @functools.wraps(func)
    def wrappered(inputs:dict[str,tf.Tensor],output_structure:dict[str,tf.TensorSpec])->dict[str,tf.Tensor]:
        inp = tuple(inputs.values())
        Tout = tuple(output_structure.values())
        flattened_output = tf.py_function(func,inp=inp,Tout=Tout)
        return dict(zip(output_structure.keys(),flattened_output))
    return wrappered  

inputs = {'k1':tf.convert_to_tensor(str(1)),
          'k2':tf.convert_to_tensor(str(2)),
          'k3':tf.convert_to_tensor(str(3)),
          'k4':tf.convert_to_tensor(str(4))}
#@tf.function 
def original_func(inputs:dict[str,tf.Tensor]): # 各元素*2
    # @tf.function会导致函数运行在图模式, 此时tf.Tensor不存在numpy()方法, 因此当前函数尽可能运行在Eager模式
    for key in inputs:
        inputs[key] = tf.convert_to_tensor(int(str(inputs[key].numpy(),encoding='UTF-8')))*2
    return inputs
print(original_func(inputs))

inputs = {'k1':tf.convert_to_tensor(str(1)),
          'k2':tf.convert_to_tensor(str(2)),
          'k3':tf.convert_to_tensor(str(3)),
          'k4':tf.convert_to_tensor(str(4))}
@tf_py_function_wrapper
def modified_func(*inputs): # 各元素*2
    return [tf.convert_to_tensor(int(str(item.numpy(),encoding='UTF-8')))*2 for item in inputs]
output_structure = {key:tf.TensorSpec(shape=None,dtype=tf.int32) for key in inputs}
print(modified_func(inputs,output_structure))

其中 original_func()是我们原本希望实现的函数功能,modified_func()是对original_func()的修改,不再是以字典作为输入输出结构,而是以序列(字典的值)作为输入输出,可以被tf.py_function接收。为了获得字典形式的输出,我们构建自定义的装饰器tf_py_function_wrapper,将modified_func()的输出序列依据output_structure映射回字典。

2. 处理输入输出structure不一致的映射

1中的例子比较简单,输入输出的structure是一致的,但有时候我们需要处理输入输出不一致的情形,比如对输入做多组patch切片,必然无法保证输出的structure不变。以下代码展示了基于1中样例,但输入输出的结构不一致的情况下tf.py_function的使用。

import tensorflow as tf 
import functools
def tf_py_function_wrapper(func=None):
    # since tf.py_function can not deal with dict directly, and its using form is not easy
    # here we make this wrapper, it can trans `func`'s output 
    # all tensors that a user want to use by calling their `numpy()` functions should become the inputs of the wrapped `func`, otherwise, `numpy()` will not work
    # since tf.nest.flatten  tf.nest.pack_sequence_as will sort dict structure's `keys` automaticlly, we do not use tf.nest here to avoid unexpected behavior
    if func is None:
        return functools.partial(tf_py_function_wrapper,)
    @functools.wraps(func)
    def wrappered(inputs:dict[str,tf.Tensor],output_structure:dict[str,tf.TensorSpec])->dict[str,tf.Tensor]:
        inp = tuple(inputs.values())
        Tout = tuple(output_structure.values())
        flattened_output = tf.py_function(func,inp=inp,Tout=Tout)
        return dict(zip(output_structure.keys(),flattened_output))
    return wrappered  

inputs = {'k1':tf.convert_to_tensor(str(1)),
          'k2':tf.convert_to_tensor(str(2)),
          'k3':tf.convert_to_tensor(str(3)),
          'k4':tf.convert_to_tensor(str(4))}
#@tf.function 
def original_func(inputs:dict[str,tf.Tensor]): # 各元素*2
    # @tf.function会导致函数运行在图模式, 此时tf.Tensor不存在numpy()方法, 因此当前函数尽可能运行在Eager模式
    for key in inputs:
        inputs[key] = tf.convert_to_tensor(int(str(inputs[key].numpy(),encoding='UTF-8')))*2
    inputs["k5"] = tf.random.normal(shape=[],dtype=tf.float32)
    inputs["k6"] = tf.random.normal(shape=[],dtype=tf.float16)
    return inputs
print(original_func(inputs))

inputs = {'k1':tf.convert_to_tensor(str(1)),
          'k2':tf.convert_to_tensor(str(2)),
          'k3':tf.convert_to_tensor(str(3)),
          'k4':tf.convert_to_tensor(str(4))}
@tf_py_function_wrapper
def modified_func(*inputs): # 各元素*2
    return [tf.convert_to_tensor(int(str(item.numpy(),encoding='UTF-8')))*2 for item in inputs]+[tf.random.normal(shape=[],dtype=tf.float32),tf.random.normal(shape=[],dtype=tf.float16)]
output_structure = {key:tf.TensorSpec(shape=None,dtype=tf.int32) for key in inputs}|\
                   {"k5":tf.TensorSpec(shape=None,dtype=tf.float32),"k6":tf.TensorSpec(shape=None,dtype=tf.float16)}
print(modified_func(inputs,output_structure))

其中 original_func()是我们原本希望实现的函数功能,相比于1中,此处original_func()返回了额外的内容,导致输入输出结构不一致。
modified_func()为了应对这种不一致,首先需要在返回的序列中,按顺序添加额外的元素,其次,需要修改output_structure,使其继续和modified_func()的返回值序列保存一种一致性,从而使输出序列能被tf_py_function_wrapper正确还原为字典形式。

总结

对比1和2中样例才可以发现,tf_py_function_wrapper的设计为tf.py_function在处理映射类型的输入张量时增添了便捷性,不论输入输出structure一致或者不一致,都可以适用。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-07 11:11:01  更:2022-05-07 11:14:02 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/4 15:44:36-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码