# 通过 session_config 开启 XLA (enable XLA globally via session_config)
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
import tensorflow as tf
# Session configuration handed to the Estimator through RunConfig below.
session_config = tf.ConfigProto(
    log_device_placement=False,
    allow_soft_placement=True,
)

# Enable XLA JIT compilation (auto-clustering) at the graph level.
session_config.graph_options.optimizer_options.global_jit_level = (
    config_pb2.OptimizerOptions.ON_1
)
# Alternative: leave the JIT level at DEFAULT.
# NOTE(review): TensorFlow's config.proto documents DEFAULT as currently
# behaving like "off" (possibly "on" in later releases), so this line does not
# by itself turn XLA on -- confirm for the TF version in use.
# session_config.graph_options.optimizer_options.global_jit_level = (
#     config_pb2.OptimizerOptions.DEFAULT
# )
# Alternative: explicitly disable XLA auto-clustering.
# session_config.graph_options.optimizer_options.global_jit_level = (
#     config_pb2.OptimizerOptions.OFF
# )

# Run-level settings; the XLA-enabled session_config is passed through here.
config = tf.estimator.RunConfig(
    session_config=session_config,
    protocol="grpc",            # server protocol, used if a server is started
    save_checkpoints_steps=10,  # checkpoint every 10 global steps
    keep_checkpoint_max=5,      # keep at most 5 recent checkpoints
    log_step_count_steps=1,     # log step/loss info every step
)

estimator = tf.estimator.Estimator(
    config=config,
    model_dir=model_dir,  # directory where checkpoints are saved
    model_fn=model_fn,    # user-defined model_fn, defined elsewhere
)
# 对特定 op 关闭 XLA (disable XLA for specific ops)
from tensorflow.python.compiler.xla import jit

# Ops created inside this scope are tagged with compile_ops=False, asking XLA
# not to compile them even when global JIT (auto-clustering) is enabled.
# NOTE(review): XLA treats jit scopes as a best-effort hint -- confirm the op
# is actually left out of XLA clusters for the TF version in use. Recent TF
# releases also expose this publicly as tf.xla.experimental.jit_scope.
with jit.experimental_jit_scope(compile_ops=False):
    a = tf.add(1, 2)