Operator adding dropout to inputs and outputs of the given cell.

tf.compat.v1.nn.rnn_cell.DropoutWrapper( *args, **kwargs )
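As a rough illustration of how the constructor might be called, the sketch below wraps an LSTM cell and enables variational dropout. It assumes TensorFlow 2.3 is installed. The helper name `build_variational_dropout_cell` and all numeric values (`num_units`, `input_depth`, `keep_prob`, `seed`) are hypothetical choices for this example and are not part of the API.

```python
def build_variational_dropout_cell(num_units=64, input_depth=32,
                                   keep_prob=0.8, seed=42):
    """Return an LSTM cell wrapped with variational dropout (illustrative values)."""
    # Imported lazily so that defining this helper does not itself need TensorFlow.
    import tensorflow as tf

    rnn_cell = tf.compat.v1.nn.rnn_cell
    base_cell = rnn_cell.LSTMCell(num_units)
    return rnn_cell.DropoutWrapper(
        base_cell,
        input_keep_prob=keep_prob,
        output_keep_prob=keep_prob,
        state_keep_prob=keep_prob,
        # Reuse one dropout mask across all time steps of a run call.
        variational_recurrent=True,
        # Needed here because variational_recurrent=True and input_keep_prob < 1.
        input_size=tf.TensorShape([input_depth]),
        # Needed whenever variational_recurrent=True.
        dtype=tf.float32,
        seed=seed,
    )
```

With `variational_recurrent=False` (the default), `input_size` and `dtype` can be left out.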

Args | |
---|---|

`cell` | an `RNNCell` to wrap; dropout is added to its inputs, outputs, and (optionally) states. |

`input_keep_prob` | unit Tensor or float between 0 and 1, input keep probability; if it is constant and 1, no input dropout will be added. |

`output_keep_prob` | unit Tensor or float between 0 and 1, output keep probability; if it is constant and 1, no output dropout will be added. |

`state_keep_prob` | unit Tensor or float between 0 and 1, state keep probability; if it is constant and 1, no state dropout will be added. State dropout is performed on the outgoing states of the cell. Note that the state components to which dropout is applied when `state_keep_prob` is in `(0, 1)` are also determined by the argument `dropout_state_filter_visitor` (e.g. by default dropout is never applied to the `c` component of an `LSTMStateTuple`). |

`variational_recurrent` | Python bool. If `True` , then the same dropout pattern is applied across all time steps per run call. If this parameter is set, `input_size` must be provided. |

`input_size` | (optional) (possibly nested tuple of) `TensorShape` objects containing the depth(s) of the input tensors expected to be passed in to the `DropoutWrapper` . Required and used iff `variational_recurrent = True` and `input_keep_prob < 1` . |

`dtype` | (optional) The `dtype` of the input, state, and output tensors. Required and used iff `variational_recurrent = True` . |

`seed` | (optional) integer, the randomness seed. |

`dropout_state_filter_visitor` | (optional) Function that takes any hierarchical level of the state and returns a scalar or depth=1 structure of Python booleans describing which terms in the state should be dropped out. If the function returns `True`, dropout is applied across this entire sublevel. If the function returns `False`, dropout is not applied across this entire sublevel. Default behavior: perform dropout on all terms except the memory (`c`) state of `LSTMStateTuple` objects, and don't try to apply dropout to `TensorArray` objects: `def dropout_state_filter_visitor(s): if isinstance(s, LSTMStateTuple): # Never perform dropout on the c state. return LSTMStateTuple(c=False, h=True) elif isinstance(s, TensorArray): return False return True` |

`**kwargs` | dict of keyword arguments for base layer. |
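The default `dropout_state_filter_visitor` described in the table is easier to read when laid out as ordinary Python. The sketch below does not import TensorFlow: `LSTMStateTuple` and `TensorArray` are local stand-ins (assumptions made so the example runs on its own), and `drop_all_visitor` is a hypothetical custom visitor.

```python
from collections import namedtuple

# Stand-in for TensorFlow's LSTMStateTuple (assumption, for illustration only).
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


class TensorArray:
    """Stand-in for tf.TensorArray (assumption, for illustration only)."""


def default_dropout_state_filter_visitor(substate):
    """Mirror of the default behavior described in the Args table."""
    if isinstance(substate, LSTMStateTuple):
        # Never perform dropout on the c (memory) state.
        return LSTMStateTuple(c=False, h=True)
    elif isinstance(substate, TensorArray):
        # Do not attempt dropout on TensorArray objects.
        return False
    # Every other state term is eligible for dropout.
    return True


def drop_all_visitor(substate):
    """Hypothetical custom visitor: mark every sublevel for dropout."""
    return True
```

A callable of this shape would be passed as `dropout_state_filter_visitor=...`. It only has an effect when `state_keep_prob` is below 1.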

Raises | |
---|---|

`TypeError` | if `cell` is not an `RNNCell`, or `dropout_state_filter_visitor` is provided but not `callable`. |

`ValueError` | if any of the keep_probs are not between 0 and 1. |
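The checks behind these errors can be approximated in plain Python. This is a minimal sketch, not the library's implementation: `check_keep_prob` and `check_visitor` are hypothetical helpers, only constant (non-Tensor) probabilities are handled, and the inclusive range `[0, 1]` is an assumption about how "between 0 and 1" is interpreted.

```python
import numbers


def check_keep_prob(name, value):
    """Hypothetical helper: validate a constant keep probability."""
    if not isinstance(value, numbers.Real):
        raise TypeError(f"{name} must be a real number, got {type(value).__name__}")
    # Assumption: the bounds 0 and 1 themselves are accepted.
    if not 0.0 <= value <= 1.0:
        raise ValueError(f"{name} must be between 0 and 1, got {value}")
    return float(value)


def check_visitor(visitor):
    """Hypothetical helper: the visitor may be omitted, otherwise it must be callable."""
    if visitor is not None and not callable(visitor):
        raise TypeError("dropout_state_filter_visitor must be callable")
    return visitor
```

Validating the values before constructing the wrapper keeps configuration mistakes separate from graph-building errors.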

Attributes | |
---|---|

`graph` | DEPRECATED FUNCTION |

`output_size` | |

`scope_name` | |

`state_size` | |

`wrapped_cell` | |

`get_initial_state`

get_initial_state( inputs=None, batch_size=None, dtype=None )

`zero_state`

zero_state( batch_size, dtype )
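The two methods above can be exercised as in the following sketch, which assumes TensorFlow 2.3 is installed. The helper name `make_initial_states` and the sizes used are hypothetical. For a wrapped LSTM cell, both calls are expected to return a zero-filled `LSTMStateTuple` whose `c` and `h` components have shape `[batch_size, num_units]`.

```python
def make_initial_states(num_units=8, batch_size=4):
    """Build a wrapped LSTM cell and request its initial state in two ways."""
    # Imported lazily so that defining this helper does not itself need TensorFlow.
    import tensorflow as tf

    rnn_cell = tf.compat.v1.nn.rnn_cell
    cell = rnn_cell.DropoutWrapper(rnn_cell.LSTMCell(num_units),
                                   output_keep_prob=0.9)

    # Explicit batch size and dtype.
    from_zero_state = cell.zero_state(batch_size, tf.float32)
    # Same information passed by keyword; alternatively `inputs` can be given
    # so that the batch size and dtype are inferred from it.
    from_initial_state = cell.get_initial_state(batch_size=batch_size,
                                                dtype=tf.float32)
    return from_zero_state, from_initial_state
```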

© 2020 The TensorFlow Authors. All rights reserved.

Licensed under the Creative Commons Attribution License 3.0.

Code samples licensed under the Apache 2.0 License.

https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/compat/v1/nn/rnn_cell/DropoutWrapper