Batches all the input tensors to the computation done by the function.

    tf.raw_ops.BatchFunction(
        in_tensors, captured_tensors, f, num_batch_threads, max_batch_size,
        batch_timeout_micros, Tout, max_enqueued_batches=10,
        allowed_batch_sizes=[], container='', shared_name='',
        batching_queue='', enable_large_batch_splitting=False, name=None
    )

So, for example, in the following code:

    import tensorflow as tf
    from tensorflow.python.ops import gen_batch_ops

    # This input will be captured.
    y = tf.placeholder_with_default(1.0, shape=[])

    @tf.Defun(tf.float32)
    def computation(a):
      return tf.matmul(a, a) + y

    # `a` is the input tensor to be batched (defined elsewhere).
    b = gen_batch_ops.batch_function(
        f=computation,
        in_tensors=[a],
        captured_tensors=computation.captured_inputs,
        Tout=[o.type for o in computation.definition.signature.output_arg],
        num_batch_threads=1,
        max_batch_size=10,
        batch_timeout_micros=100000,  # 100ms
        allowed_batch_sizes=[3, 10],
        batching_queue="")

If more than one session.run call is simultaneously trying to compute b, the values of a will be gathered, non-deterministically concatenated along the first axis, and only one thread will run the computation.
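
The gathering and concatenation described above can be pictured with a small pure-Python sketch. It is only an illustration of the idea, not the op's implementation: `run_batched` is a hypothetical helper, nested lists stand in for tensors, and the real op schedules work across threads and does not guarantee the concatenation order.

```python
# Illustrative sketch only: plain lists stand in for tensors, and
# run_batched is a hypothetical helper, not a TensorFlow API.

def computation(rows):
    """Stand-in for the batched function: doubles every value in every row."""
    return [[2 * x for x in row] for row in rows]


def run_batched(pending_inputs):
    """Concatenate inputs along the first axis, run once, split results back."""
    # Remember how many rows each caller contributed.
    sizes = [len(rows) for rows in pending_inputs]

    # Concatenate along the first axis (the batch dimension).
    batch = [row for rows in pending_inputs for row in rows]

    # A single invocation handles every caller's rows.
    batch_out = computation(batch)

    # Slice the output so that each caller gets back only its own rows.
    outputs, start = [], 0
    for size in sizes:
        outputs.append(batch_out[start:start + size])
        start += size
    return outputs


# Two concurrent callers: one with 2 rows, one with 1 row.
results = run_batched([[[1, 2], [3, 4]], [[5, 6]]])
print(results)  # [[[2, 4], [6, 8]], [[10, 12]]]
```

Each caller receives only the slice of the output that corresponds to the rows it supplied, which is why every batched argument must share the same leading (batch) dimension layout.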
Assumes that all arguments of the function are Tensors which will be batched along their first dimension.
Arguments that are captured are not batched. The session.run call which does the concatenation will use the values of the captured tensors available to it. Therefore, typical uses of captured tensors should involve values which remain unchanged across session.run calls. Inference is a good example of this.
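
To make the consequence for captured tensors concrete, here is a minimal pure-Python sketch. All names are hypothetical, and the sketch assumes that the first call in the list is the one that performs the concatenation; in the real op, which call does so is not something the caller controls.

```python
# Illustrative sketch only: hypothetical names, plain Python values.

def computation(rows, captured_y):
    """Stand-in for the function: adds the captured value to each element."""
    return [[x + captured_y for x in row] for row in rows]


def run_batch(calls):
    """Run one batch built from several (rows, captured_y) calls.

    Assumption of this sketch: the first call performs the concatenation,
    so its captured value is used for the whole batch.
    """
    _, captured_y = calls[0]
    batch = [row for rows, _ in calls for row in rows]
    return computation(batch, captured_y)


# Both callers hold the same captured value: each gets what it expects.
same = run_batch([([[1.0]], 1.0), ([[2.0]], 1.0)])

# The second caller's captured value (5.0) is never used in this sketch.
mixed = run_batch([([[1.0]], 1.0), ([[2.0]], 5.0)])
```

In the sketch, `mixed` equals `same`, so the second caller silently gets a result computed with a captured value it did not supply. This is the reason captured tensors should hold values that stay constant across calls, such as model weights during inference.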
SparseTensor is not supported. The return value of the decorated function must be a Tensor or a list/tuple of Tensors.
Args | |
---|---|
in_tensors | A list of Tensor objects. The tensors to be batched. |
captured_tensors | A list of Tensor objects. The tensors which are captured in the function, and don't need to be batched. |
f | A function decorated with @Defun. |
num_batch_threads | An int. Number of scheduling threads for processing batches of work. Determines the number of batches processed in parallel. |
max_batch_size | An int. Batch sizes will never be bigger than this. |
batch_timeout_micros | An int. Maximum number of microseconds to wait before outputting an incomplete batch. |
Tout | A list of tf.DTypes that has length >= 1. The types of the output tensors. |
max_enqueued_batches | An optional int. Defaults to 10. Maximum number of batches enqueued. |
allowed_batch_sizes | An optional list of ints. Defaults to []. Optional list of allowed batch sizes. If left empty, does nothing. Otherwise, supplies a list of batch sizes, causing the op to pad batches up to one of those sizes. The entries must increase monotonically. If enable_large_batch_splitting is false (i.e., large-input-split is not enabled), the final entry must equal max_batch_size. |
container | An optional string. Defaults to "". Controls the scope of sharing of this batch. |
shared_name | An optional string. Defaults to "". Concurrently running instances of batch in the same device with the same container and shared_name will batch their elements together. If left empty, the op name will be used as the shared name. |
batching_queue | An optional string. Defaults to "". |
enable_large_batch_splitting | An optional bool. Defaults to False. If true, an input with a large size (i.e., larger than the largest value of allowed_batch_sizes) will be split into multiple batches. |
name | A name for the operation (optional). |
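
The constraints on allowed_batch_sizes can be restated as a short pure-Python sketch. The helpers below are hypothetical and are not part of TensorFlow. The sketch also makes two assumptions that the table does not state: that "increase monotonically" means strictly increasing, and that a batch is padded up to the smallest allowed size that fits it.

```python
# Illustrative sketch only: these helpers are hypothetical and merely
# restate the rules from the argument table in plain Python.

def validate_allowed_batch_sizes(allowed, max_batch_size,
                                 enable_large_batch_splitting=False):
    """Raise ValueError if `allowed` breaks the documented constraints."""
    if not allowed:
        return  # An empty list means the option does nothing.
    # Assumption: entries are treated as strictly increasing.
    if any(a >= b for a, b in zip(allowed, allowed[1:])):
        raise ValueError("allowed_batch_sizes must increase monotonically")
    if not enable_large_batch_splitting and allowed[-1] != max_batch_size:
        raise ValueError("final entry must equal max_batch_size")


def padded_batch_size(batch_size, allowed):
    """Return the size a batch is padded to under this sketch's assumption."""
    if not allowed:
        return batch_size  # No padding when the list is empty.
    for size in allowed:
        if size >= batch_size:
            return size  # Smallest allowed size that fits the batch.
    raise ValueError("batch is larger than every allowed size")


# Matches the example configuration: allowed_batch_sizes=[3, 10].
validate_allowed_batch_sizes([3, 10], max_batch_size=10)
padded = [padded_batch_size(n, [3, 10]) for n in (1, 3, 4, 10)]
print(padded)  # [3, 3, 10, 10]
```

With allowed_batch_sizes=[3, 10], a batch of 4 rows would be padded to 10 under this assumption, so a short list of allowed sizes trades wasted computation on padding for a smaller number of distinct batch shapes.
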
Returns | |
---|---|
A list of Tensor objects of type Tout. |
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/raw_ops/BatchFunction