本文介绍常用的Tensorflow debug可能涉及的api
一、 查看tensorflow版本以及安装路径
import tensorflow as tf
print(tf.__path__)
print(tf.__version__)
二、怎么打印tensor
- tensorflow 2.0
helloworld = tf.constant("hello, TensorFlow")
print("Tensor:", helloworld)
print("Value :", helloworld.numpy())
输出:
Tensor: tf.Tensor(b'hello, TensorFlow', shape=(), dtype=string)
Value : b'hello, TensorFlow'
- tensorflow 1.x
tf.Print()函数参数
Print(
input_,
data,
message=None,
first_n=None,
summarize=None,
name=None
)
参数:
- input_:通过这个操作的张量。
- data:计算 op 时要打印的张量列表。
- message:一个字符串,错误消息的前缀。
- first_n:只记录 first_n 次数。负数日志,这是默认的。
- summarize:只打印每个张量的许多条目。如果没有,则每个输入张量最多打印3个元素。
- name:操作的名称(可选)。
返回: 该操作将返回与 input_ 相同的张量。
用法示例:
import tensorflow as tf
a = tf.Variable(tf.random_normal([3, 3, 1, 64], stddev=0.1))
a = tf.Print(a, [a], "a: ",summarize=9)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
sess.run(a)
model = tf.Print(input_=model,
data=[tf.argmax(model,1)],
message='y_hat=',
summarize=10,
first_n=5
)
输出
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[0 0 0 7 0 0 0 0 0 0...]
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[0 7 7 1 8 7 2 7 7 0...]
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[4 8 0 6 1 8 1 0 7 0...]
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[0 0 1 0 0 0 0 5 7 5...]
I tensorflow/core/kernels/logging_ops.cc:79] y_hat=[9 2 2 8 8 6 6 1 7 7...]
to be continued…
References 精通Tensorflow 1.x eat TensorFlow2 in 30 days
|