Sets the global dtype policy.
tf.keras.mixed_precision.set_global_policy(policy)
The global policy is the default tf.keras.mixed_precision.Policy
used for layers, if no policy is passed to the layer constructor.

    >>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
    >>> tf.keras.mixed_precision.global_policy()
    <Policy "mixed_float16">
    >>> tf.keras.layers.Dense(10).dtype_policy
    <Policy "mixed_float16">
    >>> # Global policy is not used if a policy is directly passed to constructor
    >>> tf.keras.layers.Dense(10, dtype='float64').dtype_policy
    <Policy "float64">
    >>> tf.keras.mixed_precision.set_global_policy('float32')

If no global policy is set, layers will instead default to a Policy constructed from tf.keras.backend.floatx().
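The fallback to floatx() can be checked with a minimal sketch like the one below. It assumes a fresh Python process in which set_global_policy() has not been called yet; the variable names are illustrative only.

```python
import tensorflow as tf

# Assumption: no global policy has been set earlier in this process.
default_float = tf.keras.backend.floatx()  # typically 'float32'

# With no explicit global policy, both the global policy and a new
# layer's policy should be named after the floatx() value.
global_name = tf.keras.mixed_precision.global_policy().name
layer_name = tf.keras.layers.Dense(10).dtype_policy.name

print(default_float, global_name, layer_name)
```

If the three printed names differ, a global policy was probably set earlier in the same process.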
To use mixed precision, the global policy should be set to 'mixed_float16' or 'mixed_bfloat16', so that every layer uses a 16-bit compute dtype and float32 variable dtype by default.
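As a rough illustration of the compute/variable split, the sketch below sets 'mixed_float16', inspects the policy a new layer picks up, and then restores whatever policy was active before. The variable names are illustrative, and restoring the previous policy is a convention chosen here, not something the API requires.

```python
import tensorflow as tf

# Remember the current policy so it can be restored afterwards.
previous = tf.keras.mixed_precision.global_policy()

tf.keras.mixed_precision.set_global_policy('mixed_float16')

# A layer created without an explicit dtype picks up the global policy.
layer = tf.keras.layers.Dense(4)
compute_dtype = layer.dtype_policy.compute_dtype    # dtype used for computations
variable_dtype = layer.dtype_policy.variable_dtype  # dtype used for the weights

print(compute_dtype, variable_dtype)

# Restore the earlier policy so later code is unaffected.
tf.keras.mixed_precision.set_global_policy(previous)
```

Setting the policy before constructing any layers matters, because each layer reads the global policy at construction time.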
Only floating-point policies, such as 'float32' and 'mixed_float16', can be set as the global policy. Non-floating-point policies such as 'int32' and 'complex64' cannot be set as the global policy because most layers do not support such policies.
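The restriction can be probed defensively, as in the sketch below. It assumes the rejection surfaces as a ValueError, which is what the release documented here is expected to raise; other Keras versions may behave differently, so the code restores the previous policy either way.

```python
import tensorflow as tf

previous = tf.keras.mixed_precision.global_policy()
rejected = False

try:
    # Assumption: a non-floating-point policy is rejected with ValueError.
    tf.keras.mixed_precision.set_global_policy('int32')
except ValueError as err:
    rejected = True
    print('Rejected:', err)
finally:
    # Restore the earlier policy whether or not the call was rejected.
    tf.keras.mixed_precision.set_global_policy(previous)

print('int32 rejected as global policy:', rejected)
```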
See tf.keras.mixed_precision.Policy for more information.
Args | |
---|---|
policy | A Policy, or a string that will be converted to a Policy. Can also be None, in which case the global policy will be constructed from tf.keras.backend.floatx(). |
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/keras/mixed_precision/set_global_policy