tf.case
是 TensorFlow 中的一个条件操作,用于根据多个条件选择执行不同的函数。它类似于 Python 中的 if-elif-else
结构,但能够在 TensorFlow 的计算图中动态选择分支。
参数说明
tf.case
的主要参数如下:
tf.case(
pred_fn_pairs,
default=None,
exclusive=False,
strict=False,
name='case'
)
• pred_fn_pairs: 一个由 (条件, 函数)
对组成的列表或字典。条件
是一个布尔张量,函数
是一个可调用的对象(通常是无参数的 lambda 或函数)。tf.case
会执行第一个条件为 True
的函数。
• default (可选): 默认函数,当所有条件为 False
时执行。如果未提供且所有条件为 False
,会报错。
• exclusive (可选): 布尔值。如果为 True
,则确保只有一个条件为 True
(类似 tf.cond
的严格模式)。默认为 False
。
• strict (可选): 布尔值。如果为 True
,所有函数的返回值必须具有相同的类型和形状。默认为 False
。
• name (可选): 操作的名称。
函数说明
1. tf.case
会按顺序评估 pred_fn_pairs
中的条件,执行第一个条件为 True
的函数。
2. 如果 exclusive=True
,会检查是否有多个条件为 True
,若有则报错。
3. 如果所有条件为 False
且未提供 default
,会抛出 ValueError
。
4. 如果 strict=True
,所有函数的返回值必须类型和形状一致(用于静态类型检查)。
用法示例
示例 1:基本用法
import tensorflow as tf
# 定义条件和函数
def f1():
return tf.constant(1)
def f2():
return tf.constant(2)
x = tf.constant(3)
y = tf.constant(4)
# 条件判断
pred_fn_pairs = [
(tf.less(x, y), f1), # 如果 x < y,执行 f1
(tf.greater(x, y), f2), # 如果 x > y,执行 f2
]
result = tf.case(pred_fn_pairs, default=lambda: tf.constant(0))
print(result.numpy()) # 输出: 1(因为 3 < 4 为 True)
示例 2:使用字典和默认函数
pred_fn_pairs = {
tf.less(x, y): lambda: tf.constant(10),
tf.greater(x, y): lambda: tf.constant(20),
}
result = tf.case(pred_fn_pairs, default=lambda: tf.constant(99))
print(result.numpy()) # 输出: 10
示例 3:exclusive=True
的严格模式
z = tf.constant(5)
pred_fn_pairs = [
(tf.equal(x, z), lambda: tf.constant(100)), # 3 == 5 为 False
(tf.less(x, z), lambda: tf.constant(200)), # 3 < 5 为 True
]
result = tf.case(pred_fn_pairs, exclusive=True)
print(result.numpy()) # 输出: 200
示例 4:strict=True
强制类型检查
pred_fn_pairs = [
(tf.less(x, y), lambda: tf.constant(1.0)), # 返回 float
(tf.greater(x, y), lambda: tf.constant(2)), # 返回 int(会报错)
]
# 以下会报错,因为 strict=True 要求返回类型一致
# result = tf.case(pred_fn_pairs, strict=True)
注意事项
1. 如果条件是通过 TensorFlow 操作(如 tf.less
)生成的布尔张量,不能直接用 Python 的 if
判断。
2. 函数应是无参数的(通常用 lambda
封装)。
3. 在 @tf.function
装饰的函数中,tf.case
可以与其他 TF 操作无缝结合。
通过 tf.case
可以实现动态分支选择,适用于需要根据张量值决定计算路径的场景。
参考
- https://blog.csdn.net/AI_LX/article/details/89465395