本文所使用的TensorFlow版本为 2.9.0-rc0 众所周知,在TensorFlow2.x中 tf.py_function 可以帮助我们让原本只能在Eager Mode下才能运行的函数体顺利运行在计算图中,获得我们预期的输出。基本使用规则见tf.py_function,这里不再赘述。
由于底层的实现倾向于序列,tf.py_function 对映射类型并不友好,即若输入类型为 dict[str,tf.Tensor] , 是无法直接输出对应的字典形式的dict[str,tf.Tensor] ,这无疑增加了我们的coding难度。
1.处理输入输出structure一致的映射
简单起见,我们构建两个同功能的函数体original_func 和modified_func ,后者支持tf.py_function ,将输入的张量翻倍,不改变输入输出的映射结构。为了体现tf.py_function 的特点与使用难点,我们以tf.string形式作为输入,则函数体内必须先将tf.string转化为一般的tf.in32或tf.float32。具体的代码如下:
import tensorflow as tf
import functools
def tf_py_function_wrapper(func=None):
if func is None:
return functools.partial(tf_py_function_wrapper,)
@functools.wraps(func)
def wrappered(inputs:dict[str,tf.Tensor],output_structure:dict[str,tf.TensorSpec])->dict[str,tf.Tensor]:
inp = tuple(inputs.values())
Tout = tuple(output_structure.values())
flattened_output = tf.py_function(func,inp=inp,Tout=Tout)
return dict(zip(output_structure.keys(),flattened_output))
return wrappered
inputs = {'k1':tf.convert_to_tensor(str(1)),
'k2':tf.convert_to_tensor(str(2)),
'k3':tf.convert_to_tensor(str(3)),
'k4':tf.convert_to_tensor(str(4))}
def original_func(inputs:dict[str,tf.Tensor]):
for key in inputs:
inputs[key] = tf.convert_to_tensor(int(str(inputs[key].numpy(),encoding='UTF-8')))*2
return inputs
print(original_func(inputs))
inputs = {'k1':tf.convert_to_tensor(str(1)),
'k2':tf.convert_to_tensor(str(2)),
'k3':tf.convert_to_tensor(str(3)),
'k4':tf.convert_to_tensor(str(4))}
@tf_py_function_wrapper
def modified_func(*inputs):
return [tf.convert_to_tensor(int(str(item.numpy(),encoding='UTF-8')))*2 for item in inputs]
output_structure = {key:tf.TensorSpec(shape=None,dtype=tf.int32) for key in inputs}
print(modified_func(inputs,output_structure))
其中 original_func() 是我们原本希望实现的函数功能,modified_func() 是对original_func() 的修改,不再是以字典作为输入输出结构,而是以序列(字典的值)作为输入输出,可以被tf.py_function 接收。为了获得字典形式的输出,我们构建自定义的装饰器tf_py_function_wrapper ,将modified_func() 的输出序列依据output_structure 映射回字典。
2. 处理输入输出structure不一致的映射
1中的例子比较简单,输入输出的structure是一致的,但有时候我们需要处理输入输出不一致的情形,比如对输入做多组patch切片,必然无法保证输出的structure不变。以下代码展示了基于1中样例,但输入输出的结构不一致的情况下tf.py_function 的使用。
import tensorflow as tf
import functools
def tf_py_function_wrapper(func=None):
if func is None:
return functools.partial(tf_py_function_wrapper,)
@functools.wraps(func)
def wrappered(inputs:dict[str,tf.Tensor],output_structure:dict[str,tf.TensorSpec])->dict[str,tf.Tensor]:
inp = tuple(inputs.values())
Tout = tuple(output_structure.values())
flattened_output = tf.py_function(func,inp=inp,Tout=Tout)
return dict(zip(output_structure.keys(),flattened_output))
return wrappered
inputs = {'k1':tf.convert_to_tensor(str(1)),
'k2':tf.convert_to_tensor(str(2)),
'k3':tf.convert_to_tensor(str(3)),
'k4':tf.convert_to_tensor(str(4))}
def original_func(inputs:dict[str,tf.Tensor]):
for key in inputs:
inputs[key] = tf.convert_to_tensor(int(str(inputs[key].numpy(),encoding='UTF-8')))*2
inputs["k5"] = tf.random.normal(shape=[],dtype=tf.float32)
inputs["k6"] = tf.random.normal(shape=[],dtype=tf.float16)
return inputs
print(original_func(inputs))
inputs = {'k1':tf.convert_to_tensor(str(1)),
'k2':tf.convert_to_tensor(str(2)),
'k3':tf.convert_to_tensor(str(3)),
'k4':tf.convert_to_tensor(str(4))}
@tf_py_function_wrapper
def modified_func(*inputs):
return [tf.convert_to_tensor(int(str(item.numpy(),encoding='UTF-8')))*2 for item in inputs]+[tf.random.normal(shape=[],dtype=tf.float32),tf.random.normal(shape=[],dtype=tf.float16)]
output_structure = {key:tf.TensorSpec(shape=None,dtype=tf.int32) for key in inputs}|\
{"k5":tf.TensorSpec(shape=None,dtype=tf.float32),"k6":tf.TensorSpec(shape=None,dtype=tf.float16)}
print(modified_func(inputs,output_structure))
其中 original_func() 是我们原本希望实现的函数功能,相比于1中,此处original_func() 返回了额外的内容,导致输入输出结构不一致。 modified_func() 为了应对这种不一致,首先需要在返回的序列中,按顺序添加额外的元素,其次,需要修改output_structure ,使其继续和modified_func() 的返回值序列保存一种一致性,从而使输出序列能被tf_py_function_wrapper 正确还原为字典形式。
总结
对比1和2中样例才可以发现,tf_py_function_wrapper 的设计为tf.py_function 在处理映射类型的输入张量时增添了便捷性,不论输入输出structure一致或者不一致,都可以适用。
|