Create a switch/case operation, i.e. an integer-indexed conditional.
tf.switch_case(branch_index, branch_fns, default=None, name='switch_case')
See also tf.case.
This op can be substantially more efficient than tf.case when exactly one branch will be selected. tf.switch_case is more like a C++ switch/case statement than tf.case, which is more like an if/elif/elif/else chain.
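To make the comparison concrete, here is a minimal sketch that computes the same result both ways in eager mode. The branch functions `f0`, `f1`, and `fallback` are hypothetical names chosen for illustration.

```python
import tensorflow as tf

# Illustrative branch functions; the names are hypothetical.
def f0(): return tf.constant(17)
def f1(): return tf.constant(31)
def fallback(): return tf.constant(-1)

x = tf.constant(1)

# tf.case: an ordered list of (predicate, callable) pairs, like if/elif/else.
r_case = tf.case(
    [(tf.equal(x, 0), f0), (tf.equal(x, 1), f1)],
    default=fallback,
)

# tf.switch_case: a single integer index selects the branch directly.
r_switch = tf.switch_case(x, branch_fns=[f0, f1], default=fallback)

print(int(r_case), int(r_switch))
```

With tf.case each predicate is a separate boolean Tensor, whereas tf.switch_case needs only the one integer index.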
The branch_fns parameter is either a dict from int to callables, a list of (int, callable) pairs, or simply a list of callables (in which case the index is implicitly the key). The branch_index Tensor is used to select the element in branch_fns with the matching int key, falling back to default if no key matches, or to the branch with key max(keys) if no default is provided. The keys must form a contiguous set from 0 to len(branch_fns) - 1.
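The three accepted forms of branch_fns and the fallback rule can be sketched as follows. This assumes eager execution, and the function names are illustrative only.

```python
import tensorflow as tf

# Illustrative branch functions; the names are hypothetical.
def f0(): return tf.constant(17)
def f1(): return tf.constant(31)
def fallback(): return tf.constant(-1)

idx = tf.constant(1)

# Three equivalent ways to pass the same branches.
as_dict = tf.switch_case(idx, branch_fns={0: f0, 1: f1}, default=fallback)
as_pairs = tf.switch_case(idx, branch_fns=[(0, f0), (1, f1)], default=fallback)
as_list = tf.switch_case(idx, branch_fns=[f0, f1], default=fallback)

# An index that matches no key.
out_of_range = tf.constant(5)
# With a default, the default callable runs.
with_default = tf.switch_case(out_of_range, branch_fns=[f0, f1], default=fallback)
# Without a default, the branch with the largest key (here key 1) runs.
without_default = tf.switch_case(out_of_range, branch_fns=[f0, f1])

print(int(as_dict), int(as_pairs), int(as_list))   # 31 31 31
print(int(with_default), int(without_default))     # -1 31
```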
tf.switch_case supports nested structures as implemented in tf.nest. All callables must return the same (possibly nested) value structure of lists, tuples, and/or named tuples.
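As a small sketch of the nested-structure requirement, the branches below each return a tuple of two tensors with matching dtypes. The names `small`, `large`, and `pick` are hypothetical, and wrapping the call in tf.function is one possible way to exercise the graph-mode path.

```python
import tensorflow as tf

# Both branches return the same structure: a (float32, int32) tuple.
def small():
    return tf.constant(1.0), tf.constant(1)

def large():
    return tf.constant(100.0), tf.constant(10)

@tf.function
def pick(i):
    # The result has the same tuple structure as the branch outputs.
    return tf.switch_case(i, branch_fns=[small, large])

total, count = pick(tf.constant(1))
print(float(total), int(count))
```

If one branch returned a single tensor and another returned a tuple, the structures would not match and the call would be expected to fail.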
Example:

    switch (branch_index) {  // c-style switch
      case 0: return 17;
      case 1: return 31;
      default: return -1;
    }

or

    branches = {0: lambda: 17, 1: lambda: 31}
    branches.get(branch_index, lambda: -1)()

The same logic expressed with tf.switch_case:

    import tensorflow as tf

    def f1(): return tf.constant(17)
    def f2(): return tf.constant(31)
    def f3(): return tf.constant(-1)

    branch_index = tf.constant(1)  # any int Tensor
    r = tf.switch_case(branch_index, branch_fns={0: f1, 1: f2}, default=f3)
    # Equivalent: tf.switch_case(branch_index, branch_fns={0: f1, 1: f2, 2: f3})
Args | |
---|---|
branch_index | An int Tensor specifying which of branch_fns should be executed. |
branch_fns | A dict mapping ints to callables, or a list of (int, callable) pairs, or simply a list of callables (in which case the index serves as the key). Each callable must return a matching structure of tensors. |
default | Optional callable that returns a structure of tensors. |
name | A name for this operation (optional). |
Returns | |
---|---|
The tensors returned by the callable identified by branch_index, or those returned by default if no key matches and default was provided, or those returned by the max-keyed branch_fn if no default is provided. |
Raises | |
---|---|
TypeError | If branch_fns is not a list/dictionary. |
TypeError | If branch_fns is a list but does not contain 2-tuples or callables. |
TypeError | If fns[i] is not callable for any i, or default is not callable. |
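A brief sketch of two of the error cases listed above, assuming eager execution. The invalid inputs are made up for illustration, and the exact error messages may differ between TensorFlow versions, so only the exception type is recorded.

```python
import tensorflow as tf

# Hypothetical invalid inputs: the branch entries are ints, not callables.
bad_inputs = [
    [(0, 17), (1, 31)],  # 2-tuples whose second element is not callable
    [17, 31],            # a list containing neither 2-tuples nor callables
]

errors = []
for bad in bad_inputs:
    try:
        tf.switch_case(tf.constant(0), branch_fns=bad)
    except TypeError as exc:
        # Record only the exception type; message text is version-dependent.
        errors.append(type(exc).__name__)

print(errors)
```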
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/switch_case