IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> tensorflow笔记之二十九——通过if条件分支理解graph构建 -> 正文阅读

[人工智能]tensorflow笔记之二十九——通过if条件分支理解graph构建

1.从一个需求开始——分类目打印AUC

最近遇到一个需求,对于一个多标签分类问题,用Estimator训练模型,在模型评估阶段打印各个类目的AUC,AUC的计算逻辑是在定义模型结构的时候实现的,AUC的打印通过estimator hook的方式实现,为了节省计算量在训练阶段不计算auc,我通过if not is_training来对训练和评估构建不同的graph(事实证明这很坑)。代码实现是这样的:

# 定义模型结构
def build_network():
	# 这里是模型结构的定义, 这里省略
	# xxxxxxxx
	pred = tf.nn.sigmoid(output)
	if not is_training: # 注意这里,是个大坑
	    for class_id in auc_idxs:
	        # 分类别计算auc
	        with tf.compat.v1.variable_scope('class_' + str(class_id)):
	        	# 单个类目的掩码
	           	class_mask = tf.transpose(tf.gather(tf.transpose(mask), class_id)) 
	           	# 单个类目有效样本index
	            true_index = tf.reshape(tf.where(tf.reshape(class_mask, [-1]) > 0), [-1]) 
	            # 单个类目的label
	           	class_label = tf.transpose(tf.gather(tf.transpose(label), class_id))
	            class_pred = tf.transpose(tf.gather(tf.transpose(pred), class_id))
	            # 单个类目的有效label
	            true_label = tf.gather(class_label, true_index)
	            true_pred = tf.gather(class_pred, true_index)
	            # 评估auc
	            class_auc, class_update_auc = tf.compat.v1.metrics.auc(real_label, true_pred)
	            # 使用identity对tensor命名
	            tf.identity(class_auc, "auc_classid_" + str(class_id))
	            tf.identity(class_update_auc, "update_auc_classid_" + str(class_id))

# 定义Hook
class LogSessionRunHook(tf.estimator.SessionRunHook):
    def __init__(self, global_batch_size, display_every_n_steps=10, run_type="train", class_ids=None):
        self.global_batch_size = global_batch_size
        self.display_every_n_steps = display_every_n_steps
        self.run_type = run_type
        self.class_ids = class_ids
        
    # 在每一步run前定义fetches
    def before_run(self, run_context):
        self.fetches = []
        for class_id in self.class_ids:
             self.fetches.append('class_{}/auc_classid_{}:0'.format(class_id, class_id))
             self.fetches.append('class_{}/update_auc_classid_{}:0'.format(class_id, class_id))
        return tf.compat.v1.train.SessionRunArgs(fetches=self.fetches)

	# 在每一步run后定义
    def after_run(self, run_context, run_values):
        res = run_values.results
        simple_fetches = [e.split("/")[-1] for e in self.fetches]
        for k, v in zip(simple_fetches, res):
            print(k, ":", v)

训练和预测均正常,能打印AUC指标

2.增加打印各个类目参与评估的样本数量

2.1 第1次错误尝试

由于各个class的样本数量差距较大,同样AUC是0.9,哪个置信哪个不置信不太好评估,所以我想加一个模型AUC评估的样本量统计,以确定统计的指标是否置信。于是我加了一个变量来统计:

# 定义模型结构
def build_network():
	# 这里是模型结构的定义, 这里省略
	# xxxxxxxx
	pred = tf.nn.sigmoid(output)
	if not is_training:
	    for class_id in auc_idxs:
	        # 分类别计算auc
	        with tf.compat.v1.variable_scope('class_' + str(class_id)):
	            # 样本数量
        	    sample_num = tf.compat.v1.get_variable("sample_num_classid_"+str(class_id), 
        	                                           shape=[],
        	                                           initializer=tf.zeros_initializer())
	        	# 单个类目的掩码
	           	class_mask = tf.transpose(tf.gather(tf.transpose(mask), class_id)) 
	           	# 累计当前类目的样本数
	           	updated_sample_num = tf.compat.v1.assign_add(sample_num,
	           	                                             tf.reduce_sum(class_mask))
	           	# 实际需要打印的是updated_sample_num
	           	tf.identity(updated_sample_num, "update_sample_num_classid_"+str(class_id))
	           	# 当前类目样本筛选代码省略
	           	# xxxxxxxx
	           	# 评估auc
	            class_auc, class_update_auc = tf.compat.v1.metrics.auc(real_label, true_pred)
	            # 使用identity对tensor命名
	            tf.identity(class_auc, "auc_classid_" + str(class_id))
	            tf.identity(class_update_auc, "update_auc_classid_" + str(class_id))

# 定义Hook
class LogSessionRunHook(tf.estimator.SessionRunHook):
    # 重复代码省略
    def before_run(self, run_context):
        self.fetches = [] 
        for class_id in self.class_ids:
            # 重复代码省略
            self.fetches.append('class_{}/update_sample_num_classid_{}:0'
                                .format(class_id, class_id))
        return tf.compat.v1.train.SessionRunArgs(fetches=self.fetches)

评估加载模型后,运行报错如下,Key class_0/sample_num_classid_0 not found in checkpoint找不到我们定义的变量:

2021-10-24 11:09:50.418 INFO:tensorflow:Restoring parameters from /cephfs/starxhong/game/mmoe//20211016/model2/model/model.ckpt-41995
2021-10-24 11:09:52.759 2021-10-24 11:09:52.759554: W tensorflow/core/framework/op_kernel.cc:1651] OP_REQUIRES failed at save_restore_v2_ops.cc:184 : Not found: Key class_0/sample_num_classid_0 not found in checkpoint
2021-10-24 11:09:52.780 Traceback (most recent call last):
2021-10-24 11:09:52.780 File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/client/session.py", line 1365, in _do_call
2021-10-24 11:09:52.780 return fn(*args)
2021-10-24 11:09:52.780 File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/client/session.py", line 1350, in _run_fn
2021-10-24 11:09:52.780 target_list, run_metadata)
2021-10-24 11:09:52.780 File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/client/session.py", line 1443, in _call_tf_sessionrun
2021-10-24 11:09:52.780 run_metadata)
2021-10-24 11:09:52.780 tensorflow.python.framework.errors_impl.NotFoundError: 2 root error(s) found.
2021-10-24 11:09:52.780 (0) Not found: Key class_0/sample_num_classid_0 not found in checkpoint
2021-10-24 11:09:52.780 [[{{node save/RestoreV2}}]]
2021-10-24 11:09:52.780 [[save/RestoreV2/_301]]

原因分析:
在训练时,由于有if条件,模型没有保存新增的class_0/sample_num_classid_0:0变量,但在评估构建图的过程中使用get_variable创建了class_0/sample_num_classid_0,模型加载的时候会去模型里找这个变量,当然是找不到了,于是报错了。

2.2 第2次错误尝试

那能不能把变量的创建放到if条件外呢?改为:

# 定义模型结构
def build_network():
	# 这里是模型结构的定义, 这里省略
	# xxxxxxxx
	pred = tf.nn.sigmoid(output)
	# 将get_variable放到if not is_training外面:
	sample_num = {}
	for class_id in auc_idxs:
        with tf.compat.v1.variable_scope('class_' + str(class_id)):
       	    sample_num[class_id] = tf.compat.v1.get_variable("sample_num_classid_"+str(class_id), 
       	                                                     shape=[],
       	                                                     initializer=tf.zeros_initializer())
	if not is_training:
	    for class_id in auc_idxs:
	        # 分类别计算auc
	        with tf.compat.v1.variable_scope('class_' + str(class_id)):
	            
	        	# 单个类目的掩码
	           	class_mask = tf.transpose(tf.gather(tf.transpose(mask), class_id)) 
	           	# 累计当前类目的样本数
	           	updated_sample_num = tf.compat.v1.assign_add(sample_num[class_id],
	           	                                             tf.reduce_sum(class_mask))
	           	# 实际需要打印的是updated_sample_num
	           	tf.identity(updated_sample_num, "update_sample_num_classid_"+str(class_id))
	           	# 当前类目样本筛选代码省略
	           	# 评估auc
	            class_auc, class_update_auc = tf.compat.v1.metrics.auc(real_label, true_pred)
	            # 使用identity对tensor命名
	            tf.identity(class_auc, "auc_classid_" + str(class_id))
	            tf.identity(class_update_auc, "update_auc_classid_" + str(class_id))

会报错:

2021-10-26 16:03:57.744 Traceback (most recent call last):
2021-10-26 16:03:57.744 File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/client/session.py", line 305, in __init__
2021-10-26 16:03:57.744 fetch, allow_tensor=True, allow_operation=True))
2021-10-26 16:03:57.744 File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/framework/ops.py", line 3604, in as_graph_element
2021-10-26 16:03:57.744 return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
2021-10-26 16:03:57.744 File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/framework/ops.py", line 3649, in _as_graph_element_locked
2021-10-26 16:03:57.744 "graph." % (repr(name), repr(op_name)))
2021-10-26 16:03:57.744 KeyError: "The name 'class_0/class_auc_classid_0:0' refers to a Tensor which does not exist. The operation, 'class_0/class_auc_classid_0:0', does not exist in the graph."

这次是class_0/class_auc_classid_0:0指向了一个不存在的tensor(也就是class_auc),这是之前不会错误的地方,为什么会出现tensor不存在的问题暂时还清楚,我猜跟variable_scope有关,于是我做了以下修改:

sample_num = {}
	for class_id in auc_idxs:
		# 改variable_scope名称
        with tf.compat.v1.variable_scope('class_new_' + str(class_id)):
       	    # xxxx

再次运行就能正常跑了。至于为什么正常了,我也很奇怪,同一个名称的variable_scope可以多次使用肯定是没问题的,我也猜测过get_variable默认是global_variable,而class_auc是local_variable,但测试过同一个variable_scope下可以既有global_variable又有local_variable。

2.3 正确的打开方式

上面的方法虽然能正常跑了,但是很蹩脚,代码看起来很奇怪。上面的问题归根到底是由于if not is_training这个if判断条件导致的训练和预测模型结构不一致。从上面的经验来看,get_variable不能只出现在评估阶段,否则会出现key not found in checkpoint的问题,但是如果不涉及到get_variable,评估阶段仅仅多一些tensor计算(比如最上面的的auc计算),好像也能正常跑。
不过,tensorflow的graph中不应该出现python的if条件分支语句,如果一定要加分支,需要用tf.cond()函数来处理,graph应该是仅仅tensor和op的组合。而且,一般tf.cond()作用于input内容的条件判断上,而不是像我们这种is_training的判断。正确的打开方式是,在构建graph的时候不区分训练还是预测、评估,可以通过hook中加入判断条件来决定在评估阶段跑auc计算任务。
代码如下:

# 定义模型结构
def build_network():
	# 这里是模型结构的定义, 这里省略
	# xxxxxxxx
	pred = tf.nn.sigmoid(output)
    for class_id in auc_idxs:
        # 分类别计算auc
        with tf.compat.v1.variable_scope('class_' + str(class_id)):
       	    # 单个类目的掩码
          	class_mask = tf.transpose(tf.gather(tf.transpose(mask), class_id)) 
          	# 样本数量
        	sample_num = tf.compat.v1.get_variable("sample_num_classid_"+str(class_id), 
        	                                           shape=[],
        	                                           initializer=tf.zeros_initializer())
            # 累计当前类目的样本数
           	updated_sample_num = tf.compat.v1.assign_add(sample_num,
           	                                             tf.reduce_sum(class_mask))
           	# 实际需要打印的是updated_sample_num
           	tf.identity(updated_sample_num, "update_sample_num_classid_"+str(class_id))
          	# 单个类目有效样本index
            true_index = tf.reshape(tf.where(tf.reshape(class_mask, [-1]) > 0), [-1]) 
            # 单个类目的label
          	class_label = tf.transpose(tf.gather(tf.transpose(label), class_id))
            class_pred = tf.transpose(tf.gather(tf.transpose(pred), class_id))
            # 单个类目的有效label
            true_label = tf.gather(class_label, true_index)
            true_pred = tf.gather(class_pred, true_index)
            # 评估auc
            class_auc, class_update_auc = tf.compat.v1.metrics.auc(real_label, true_pred)
            # 使用identity对tensor命名
            tf.identity(class_auc, "auc_classid_" + str(class_id))
            tf.identity(class_update_auc, "update_auc_classid_" + str(class_id))

# 定义Hook
class LogSessionRunHook(tf.estimator.SessionRunHook):
    def __init__(self, global_batch_size, display_every_n_steps=10, run_type="train", class_ids=None):
        self.global_batch_size = global_batch_size
        self.display_every_n_steps = display_every_n_steps
        self.run_type = run_type
        self.class_ids = class_ids
        
    # 在每一步run前定义fetches
    def before_run(self, run_context):
        self.fetches = []
        if self.run_type == 'eval':
            for class_id in self.class_ids:
                self.fetches.append('class_{}/auc_classid_{}:0'.format(class_id, class_id))
                self.fetches.append('class_{}/update_auc_classid_{}:0'.format(class_id, class_id))
        return tf.compat.v1.train.SessionRunArgs(fetches=self.fetches)

	# 在每一步run后定义
    def after_run(self, run_context, run_values):
        if self.run_type == "eval":
            res = run_values.results
            simple_fetches = [e.split("/")[-1] for e in self.fetches]
            for k, v in zip(simple_fetches, res):
                print(k, ":", v)
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-27 12:50:24  更:2021-10-27 12:50:32 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 8:08:01-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码