1.从一个需求开始——分类目打印AUC
最近遇到一个需求,对于一个多标签分类问题,用Estimator训练模型,在模型评估阶段打印各个类目的AUC,AUC的计算逻辑是在定义模型结构的时候实现的,AUC的打印通过estimator hook的方式实现,为了节省计算量在训练阶段不计算auc,我通过if not is_training来对训练和评估构建不同的graph(事实证明这很坑)。代码实现是这样的:
def build_network():
    """Build the sigmoid predictions and, outside of training, per-class AUC ops.

    For each class id in ``auc_idxs`` the column of ``mask``/``label``/``pred``
    belonging to that class is extracted, rows whose mask value is > 0 are
    kept, and a streaming AUC is built on the kept rows. The value tensor and
    the update op are re-exported through ``tf.identity`` under stable names
    (``class_<id>/auc_classid_<id>`` and ``class_<id>/update_auc_classid_<id>``)
    so that a session hook can fetch them by name.

    NOTE(review): ``output``, ``is_training``, ``auc_idxs``, ``mask`` and
    ``label`` are not defined in this snippet; they are assumed to come from
    the enclosing model code.
    """
    pred = tf.nn.sigmoid(output)
    if not is_training:
        for class_id in auc_idxs:
            with tf.compat.v1.variable_scope('class_' + str(class_id)):
                # Select this class's column from mask / label / prediction.
                class_mask = tf.transpose(tf.gather(tf.transpose(mask), class_id))
                # Indices of the rows whose mask value for this class is > 0.
                true_index = tf.reshape(tf.where(tf.reshape(class_mask, [-1]) > 0), [-1])
                class_label = tf.transpose(tf.gather(tf.transpose(label), class_id))
                class_pred = tf.transpose(tf.gather(tf.transpose(pred), class_id))
                true_label = tf.gather(class_label, true_index)
                true_pred = tf.gather(class_pred, true_index)
                # Fix: the original passed the undefined name `real_label`;
                # the masked labels computed above are `true_label`.
                class_auc, class_update_auc = tf.compat.v1.metrics.auc(true_label, true_pred)
                # Give both tensors predictable names for fetch-by-name in the hook.
                tf.identity(class_auc, "auc_classid_" + str(class_id))
                tf.identity(class_update_auc, "update_auc_classid_" + str(class_id))
class LogSessionRunHook(tf.estimator.SessionRunHook):
    """Session hook that fetches and prints the per-class AUC tensors by name."""

    def __init__(self, global_batch_size, display_every_n_steps=10, run_type="train", class_ids=None):
        """Store the hook configuration.

        Args:
            global_batch_size: stored as-is; not used by the methods shown here.
            display_every_n_steps: stored as-is; not used by the methods shown here.
            run_type: stored as-is; not used by the methods shown here.
            class_ids: iterable of class ids whose AUC tensors are fetched.
                ``None`` is treated as "no classes".
        """
        self.global_batch_size = global_batch_size
        self.display_every_n_steps = display_every_n_steps
        self.run_type = run_type
        self.class_ids = class_ids

    def before_run(self, run_context):
        """Request the AUC value and AUC update tensors of every class."""
        self.fetches = []
        # `or ()` keeps the default class_ids=None from raising TypeError.
        for class_id in self.class_ids or ():
            self.fetches.append('class_{}/auc_classid_{}:0'.format(class_id, class_id))
            self.fetches.append('class_{}/update_auc_classid_{}:0'.format(class_id, class_id))
        return tf.compat.v1.train.SessionRunArgs(fetches=self.fetches)

    def after_run(self, run_context, run_values):
        """Print each fetched tensor as ``<short name> : <value>``."""
        res = run_values.results
        # Drop the "class_<id>/" scope prefix for a shorter printed label.
        simple_fetches = [e.split("/")[-1] for e in self.fetches]
        for k, v in zip(simple_fetches, res):
            print(k, ":", v)
训练和预测均正常,能打印AUC指标
2.增加打印各个类目参与评估的样本数量
2.1 第1次错误尝试
由于各个class的样本数量差距较大,同样AUC是0.9,哪个置信哪个不置信不太好评估,所以我想加一个模型AUC评估的样本量统计,以确定统计的指标是否置信。于是我加了一个变量来统计:
def build_network():
    """First (failing) attempt: count evaluated samples per class next to the AUC.

    A scalar counter variable is created per class and incremented by the sum
    of that class's mask column; the incremented value is exported as
    ``class_<id>/update_sample_num_classid_<id>``.

    NOTE: the counter variable is deliberately created only on the
    ``not is_training`` branch, as in the attempt being illustrated. A
    checkpoint written by the training graph therefore has no such key, and
    restoring it into the evaluation graph fails with
    ``Key class_0/sample_num_classid_0 not found in checkpoint``.

    NOTE(review): ``output``, ``is_training``, ``auc_idxs``, ``mask`` and
    ``label`` are assumed to come from the enclosing model code.
    """
    pred = tf.nn.sigmoid(output)
    if not is_training:
        for class_id in auc_idxs:
            with tf.compat.v1.variable_scope('class_' + str(class_id)):
                sample_num = tf.compat.v1.get_variable("sample_num_classid_"+str(class_id),
                                                       shape=[],
                                                       initializer=tf.zeros_initializer())
                class_mask = tf.transpose(tf.gather(tf.transpose(mask), class_id))
                # Running total of the mask values seen for this class.
                updated_sample_num = tf.compat.v1.assign_add(sample_num,
                                                             tf.reduce_sum(class_mask))
                tf.identity(updated_sample_num, "update_sample_num_classid_"+str(class_id))
                # Masked labels/predictions, restored here so the snippet is
                # self-contained (the original elided these lines).
                true_index = tf.reshape(tf.where(tf.reshape(class_mask, [-1]) > 0), [-1])
                class_label = tf.transpose(tf.gather(tf.transpose(label), class_id))
                class_pred = tf.transpose(tf.gather(tf.transpose(pred), class_id))
                true_label = tf.gather(class_label, true_index)
                true_pred = tf.gather(class_pred, true_index)
                # Fix: `real_label` was never defined; use the masked `true_label`.
                class_auc, class_update_auc = tf.compat.v1.metrics.auc(true_label, true_pred)
                tf.identity(class_auc, "auc_classid_" + str(class_id))
                tf.identity(class_update_auc, "update_auc_classid_" + str(class_id))
class LogSessionRunHook(tf.estimator.SessionRunHook):
    """Session hook (abbreviated) that fetches the per-class sample counters."""

    def before_run(self, run_context):
        """Request the updated sample-count tensor of every class in ``self.class_ids``."""
        name_template = 'class_{}/update_sample_num_classid_{}:0'
        self.fetches = []
        self.fetches.extend(
            name_template.format(cid, cid) for cid in self.class_ids
        )
        return tf.compat.v1.train.SessionRunArgs(fetches=self.fetches)
评估加载模型后,运行报错如下,Key class_0/sample_num_classid_0 not found in checkpoint找不到我们定义的变量:
2021-10-24 11:09:50.418 INFO:tensorflow:Restoring parameters from /cephfs/starxhong/game/mmoe//20211016/model2/model/model.ckpt-41995
2021-10-24 11:09:52.759 2021-10-24 11:09:52.759554: W tensorflow/core/framework/op_kernel.cc:1651] OP_REQUIRES failed at save_restore_v2_ops.cc:184 : Not found: Key class_0/sample_num_classid_0 not found in checkpoint
2021-10-24 11:09:52.780 Traceback (most recent call last):
2021-10-24 11:09:52.780 File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/client/session.py", line 1365, in _do_call
2021-10-24 11:09:52.780 return fn(*args)
2021-10-24 11:09:52.780 File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/client/session.py", line 1350, in _run_fn
2021-10-24 11:09:52.780 target_list, run_metadata)
2021-10-24 11:09:52.780 File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/client/session.py", line 1443, in _call_tf_sessionrun
2021-10-24 11:09:52.780 run_metadata)
2021-10-24 11:09:52.780 tensorflow.python.framework.errors_impl.NotFoundError: 2 root error(s) found.
2021-10-24 11:09:52.780 (0) Not found: Key class_0/sample_num_classid_0 not found in checkpoint
2021-10-24 11:09:52.780 [[{{node save/RestoreV2}}]]
2021-10-24 11:09:52.780 [[save/RestoreV2/_301]]
原因分析: 在训练时,由于有if条件,模型没有保存新增的class_0/sample_num_classid_0:0变量,但在评估构建图的过程中使用get_variable创建了class_0/sample_num_classid_0,模型加载的时候会去模型里找这个变量,当然是找不到了,于是报错了。
2.2 第2次错误尝试
那能不能把变量的创建放到if条件外呢?改为:
def build_network():
    """Second (failing) attempt: create the counters unconditionally.

    The per-class counter variables are created for both training and
    evaluation, so they exist in the checkpoint; the increment and the AUC
    ops are still built only when ``is_training`` is false.

    NOTE: the second loop deliberately re-enters ``variable_scope('class_<id>')``
    with the same string, as in the attempt being illustrated. With this
    layout the hook's fetch-by-name fails with a ``KeyError`` saying the
    tensor does not exist in the graph.
    NOTE(review): presumably this is because re-entering a variable scope by
    name opens a uniquified name scope (e.g. ``class_0_1/``) for the ops
    created inside it; confirm against the TF1 variable_scope documentation.

    NOTE(review): ``output``, ``is_training``, ``auc_idxs``, ``mask`` and
    ``label`` are assumed to come from the enclosing model code.
    """
    pred = tf.nn.sigmoid(output)
    sample_num = {}
    for class_id in auc_idxs:
        with tf.compat.v1.variable_scope('class_' + str(class_id)):
            sample_num[class_id] = tf.compat.v1.get_variable("sample_num_classid_"+str(class_id),
                                                             shape=[],
                                                             initializer=tf.zeros_initializer())
    if not is_training:
        for class_id in auc_idxs:
            with tf.compat.v1.variable_scope('class_' + str(class_id)):
                class_mask = tf.transpose(tf.gather(tf.transpose(mask), class_id))
                updated_sample_num = tf.compat.v1.assign_add(sample_num[class_id],
                                                             tf.reduce_sum(class_mask))
                tf.identity(updated_sample_num, "update_sample_num_classid_"+str(class_id))
                # Masked labels/predictions, restored here so the snippet is
                # self-contained (the original elided these lines).
                true_index = tf.reshape(tf.where(tf.reshape(class_mask, [-1]) > 0), [-1])
                class_label = tf.transpose(tf.gather(tf.transpose(label), class_id))
                class_pred = tf.transpose(tf.gather(tf.transpose(pred), class_id))
                true_label = tf.gather(class_label, true_index)
                true_pred = tf.gather(class_pred, true_index)
                # Fix: `real_label` was never defined; use the masked `true_label`.
                class_auc, class_update_auc = tf.compat.v1.metrics.auc(true_label, true_pred)
                tf.identity(class_auc, "auc_classid_" + str(class_id))
                tf.identity(class_update_auc, "update_auc_classid_" + str(class_id))
会报错:
2021-10-26 16:03:57.744 Traceback (most recent call last):
2021-10-26 16:03:57.744 File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/client/session.py", line 305, in __init__
2021-10-26 16:03:57.744 fetch, allow_tensor=True, allow_operation=True))
2021-10-26 16:03:57.744 File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/framework/ops.py", line 3604, in as_graph_element
2021-10-26 16:03:57.744 return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
2021-10-26 16:03:57.744 File "/usr/local/lib64/python3.6/site-packages/tensorflow_core/python/framework/ops.py", line 3649, in _as_graph_element_locked
2021-10-26 16:03:57.744 "graph." % (repr(name), repr(op_name)))
2021-10-26 16:03:57.744 KeyError: "The name 'class_0/class_auc_classid_0:0' refers to a Tensor which does not exist. The operation, 'class_0/class_auc_classid_0:0', does not exist in the graph."
这次是class_0/class_auc_classid_0:0指向了一个不存在的tensor(也就是class_auc),这是之前不会出错的地方,为什么会出现tensor不存在的问题暂时还不清楚,我猜跟variable_scope有关,于是我做了以下修改:
sample_num = {}
for class_id in auc_idxs:
with tf.compat.v1.variable_scope('class_new_' + str(class_id)):
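再次运行就能正常跑了。至于为什么正常了,我一开始也很奇怪,同一个名称的variable_scope可以多次使用肯定是没问题的,我也猜测过get_variable默认是global_variable,而class_auc是local_variable,但测试过同一个variable_scope下可以既有global_variable又有local_variable。补充说明:真正的原因在于name_scope。在TF1中,用同一个字符串第二次进入variable_scope时,变量名的前缀仍然是class_0/,但它内部打开的name_scope会被自动去重成class_0_1/,所以第二次进入后由tf.identity等op生成的tensor实际名称类似class_0_1/auc_classid_0:0,hook里按class_0/...去取自然就找不到了。把创建变量的scope改名为class_new_0之后,if分支里的class_0变成第一次进入,op名称的前缀又恢复为class_0/,于是就能正常取到。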
2.3 正确的打开方式
上面的方法虽然能正常跑了,但是很蹩脚,代码看起来很奇怪。上面的问题归根到底是由于if not is_training这个if判断条件导致训练和预测的模型结构不一致。从上面的经验来看,get_variable不能只出现在评估阶段,否则会出现key not found in checkpoint的问题,但是如果不涉及到get_variable,评估阶段仅仅多一些tensor计算(比如最上面的auc计算),好像也能正常跑。 不过,tensorflow的graph中不应该出现python的if条件分支语句,如果一定要加分支,需要用tf.cond()函数来处理,graph应该仅仅是tensor和op的组合。而且,一般tf.cond()作用于input内容的条件判断上,而不是像我们这种is_training的判断。正确的打开方式是,在构建graph的时候不区分训练还是预测、评估,而是通过在hook中加入判断条件,来决定只在评估阶段跑auc计算任务。 代码如下:
def build_network():
    """Build predictions plus per-class sample counters and AUC ops.

    The same ops are built regardless of run mode, so the training and
    evaluation graphs contain the same variables. Nothing here runs unless
    it is fetched; a session hook decides whether to fetch the counters and
    AUC tensors by name:

    * ``class_<id>/update_sample_num_classid_<id>``
    * ``class_<id>/auc_classid_<id>``
    * ``class_<id>/update_auc_classid_<id>``

    NOTE(review): ``output``, ``auc_idxs``, ``mask`` and ``label`` are not
    defined in this snippet; they are assumed to come from the enclosing
    model code.
    """
    pred = tf.nn.sigmoid(output)
    for class_id in auc_idxs:
        with tf.compat.v1.variable_scope('class_' + str(class_id)):
            # Select this class's column of the mask.
            class_mask = tf.transpose(tf.gather(tf.transpose(mask), class_id))
            # The counter is bookkeeping, not a model weight: keep it out of
            # the trainable-variables collection so optimizers ignore it.
            sample_num = tf.compat.v1.get_variable("sample_num_classid_"+str(class_id),
                                                   shape=[],
                                                   initializer=tf.zeros_initializer(),
                                                   trainable=False)
            # Running total of the mask values seen for this class.
            updated_sample_num = tf.compat.v1.assign_add(sample_num,
                                                         tf.reduce_sum(class_mask))
            tf.identity(updated_sample_num, "update_sample_num_classid_"+str(class_id))
            # Indices of the rows whose mask value for this class is > 0.
            true_index = tf.reshape(tf.where(tf.reshape(class_mask, [-1]) > 0), [-1])
            class_label = tf.transpose(tf.gather(tf.transpose(label), class_id))
            class_pred = tf.transpose(tf.gather(tf.transpose(pred), class_id))
            true_label = tf.gather(class_label, true_index)
            true_pred = tf.gather(class_pred, true_index)
            # Fix: the original passed the undefined name `real_label`;
            # the masked labels computed above are `true_label`.
            class_auc, class_update_auc = tf.compat.v1.metrics.auc(true_label, true_pred)
            # Give both tensors predictable names for fetch-by-name in the hook.
            tf.identity(class_auc, "auc_classid_" + str(class_id))
            tf.identity(class_update_auc, "update_auc_classid_" + str(class_id))
class LogSessionRunHook(tf.estimator.SessionRunHook):
    """Session hook that prints per-class sample counts and AUC during evaluation.

    The graph is identical for every run mode; this hook requests the
    per-class tensors only when ``run_type == 'eval'``.
    """

    def __init__(self, global_batch_size, display_every_n_steps=10, run_type="train", class_ids=None):
        """Store the hook configuration.

        Args:
            global_batch_size: stored as-is; not used by the methods shown here.
            display_every_n_steps: stored as-is; not used by the methods shown here.
            run_type: the per-class tensors are fetched and printed only when
                this equals ``"eval"``.
            class_ids: iterable of class ids to report. ``None`` is treated
                as "no classes".
        """
        self.global_batch_size = global_batch_size
        self.display_every_n_steps = display_every_n_steps
        self.run_type = run_type
        self.class_ids = class_ids

    def before_run(self, run_context):
        """Request the per-class tensors in eval mode; request nothing otherwise."""
        self.fetches = []
        if self.run_type == 'eval':
            # `or ()` keeps the default class_ids=None from raising TypeError.
            for class_id in self.class_ids or ():
                # Fix: also fetch the running sample count, which the original
                # final hook omitted, so the count is updated and printed.
                self.fetches.append('class_{}/update_sample_num_classid_{}:0'.format(class_id, class_id))
                self.fetches.append('class_{}/auc_classid_{}:0'.format(class_id, class_id))
                self.fetches.append('class_{}/update_auc_classid_{}:0'.format(class_id, class_id))
        return tf.compat.v1.train.SessionRunArgs(fetches=self.fetches)

    def after_run(self, run_context, run_values):
        """In eval mode, print each fetched tensor as ``<short name> : <value>``."""
        if self.run_type == "eval":
            res = run_values.results
            # Drop the "class_<id>/" scope prefix for a shorter printed label.
            simple_fetches = [e.split("/")[-1] for e in self.fetches]
            for k, v in zip(simple_fetches, res):
                print(k, ":", v)
|