tf.case 是 TensorFlow 中的一个条件操作,用于根据多个条件选择执行不同的函数。它类似于 Python 中的 if-elif-else 结构,但能够在 TensorFlow 的计算图中动态选择分支。

参数说明 tf.case 的主要参数如下:

tf.case(
    pred_fn_pairs,
    default=None,
    exclusive=False,
    strict=False,
    name='case'
)

• pred_fn_pairs: 一个由 (条件, 函数) 对组成的列表或字典。条件是一个布尔张量,函数是一个可调用的对象(通常是无参数的 lambda 或函数)。tf.case 会执行第一个条件为 True 的函数。

• default (可选): 默认函数,当所有条件为 False 时执行。如果未提供且所有条件为 False,会报错。

• exclusive (可选): 布尔值。如果为 True,则确保只有一个条件为 True(类似 tf.cond 的严格模式)。默认为 False

• strict (可选): 布尔值。如果为 True,所有函数的返回值必须具有相同的类型和形状。默认为 False

• name (可选): 操作的名称。


函数说明 1. tf.case 会按顺序评估 pred_fn_pairs 中的条件,执行第一个条件为 True 的函数。 2. 如果 exclusive=True,会检查是否有多个条件为 True,若有则报错。 3. 如果所有条件为 False 且未提供 default,会抛出 ValueError。 4. 如果 strict=True,所有函数的返回值必须类型和形状一致(用于静态类型检查)。


用法示例

示例 1:基本用法

import tensorflow as tf

# 定义条件和函数
def f1():
    return tf.constant(1)

def f2():
    return tf.constant(2)

x = tf.constant(3)
y = tf.constant(4)

# 条件判断
pred_fn_pairs = [
    (tf.less(x, y), f1),  # 如果 x < y,执行 f1
    (tf.greater(x, y), f2),  # 如果 x > y,执行 f2
]

result = tf.case(pred_fn_pairs, default=lambda: tf.constant(0))
print(result.numpy())  # 输出: 1(因为 3 < 4 为 True)

示例 2:使用字典和默认函数

pred_fn_pairs = {
    tf.less(x, y): lambda: tf.constant(10),
    tf.greater(x, y): lambda: tf.constant(20),
}

result = tf.case(pred_fn_pairs, default=lambda: tf.constant(99))
print(result.numpy())  # 输出: 10

示例 3:exclusive=True 的严格模式

z = tf.constant(5)
pred_fn_pairs = [
    (tf.equal(x, z), lambda: tf.constant(100)),  # 3 == 5 为 False
    (tf.less(x, z), lambda: tf.constant(200)),   # 3 < 5 为 True
]

result = tf.case(pred_fn_pairs, exclusive=True)
print(result.numpy())  # 输出: 200

示例 4:strict=True 强制类型检查

pred_fn_pairs = [
    (tf.less(x, y), lambda: tf.constant(1.0)),  # 返回 float
    (tf.greater(x, y), lambda: tf.constant(2)),  # 返回 int(会报错)
]

# 以下会报错,因为 strict=True 要求返回类型一致
# result = tf.case(pred_fn_pairs, strict=True)


注意事项 1. 如果条件是通过 TensorFlow 操作(如 tf.less)生成的布尔张量,不能直接用 Python 的 if 判断。 2. 函数应是无参数的(通常用 lambda 封装)。 3. 在 @tf.function 装饰的函数中,tf.case 可以与其他 TF 操作无缝结合。

通过 tf.case 可以实现动态分支选择,适用于需要根据张量值决定计算路径的场景。

参考

  • https://blog.csdn.net/AI_LX/article/details/89465395