Initializer base class: all Keras initializers inherit from this class.
Initializers should implement a __call__
method with the following signature:
    def __call__(self, shape, dtype=None, **kwargs):
        # returns a tensor of shape `shape` and dtype `dtype`
        # containing values drawn from a distribution of your choice.
Optionally, you can also implement the method get_config
and the class method from_config
in order to support serialization -- just like with any Keras object.
Here's a simple example: a random normal initializer.
    import tensorflow as tf

    class ExampleRandomNormal(tf.keras.initializers.Initializer):

        def __init__(self, mean, stddev):
            self.mean = mean
            self.stddev = stddev

        def __call__(self, shape, dtype=None, **kwargs):
            return tf.random.normal(
                shape, mean=self.mean, stddev=self.stddev, dtype=dtype)

        def get_config(self):  # To support serialization
            return {"mean": self.mean, "stddev": self.stddev}
Note that we don't have to implement from_config
in the example above since the constructor arguments of the class and the keys in the config returned by get_config
are the same. In this case, the default from_config
works fine.
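As a quick check of this point, the sketch below round-trips the example initializer through `get_config` and the inherited `from_config`. The class is repeated so the block runs on its own, and the values `0.0` and `0.05` are arbitrary choices for illustration.

```python
import tensorflow as tf


class ExampleRandomNormal(tf.keras.initializers.Initializer):
    """Same initializer as in the example above, repeated for self-containment."""

    def __init__(self, mean, stddev):
        self.mean = mean
        self.stddev = stddev

    def __call__(self, shape, dtype=None, **kwargs):
        return tf.random.normal(
            shape, mean=self.mean, stddev=self.stddev, dtype=dtype)

    def get_config(self):
        return {"mean": self.mean, "stddev": self.stddev}


original = ExampleRandomNormal(mean=0.0, stddev=0.05)
config = original.get_config()

# The config keys ("mean", "stddev") match the __init__ parameter names,
# so the inherited from_config can rebuild the object without an override.
restored = ExampleRandomNormal.from_config(config)
```

After this runs, `restored` is a separate `ExampleRandomNormal` instance carrying the same `mean` and `stddev` as `original`.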
from_config
@classmethod
from_config(config)
Instantiates an initializer from a configuration dictionary.
    initializer = RandomUniform(-1, 1)
    config = initializer.get_config()
    initializer = RandomUniform.from_config(config)
Args | |
---|---|
config | A Python dictionary, the output of get_config. |
Returns | |
---|---|
A tf.keras.initializers.Initializer instance. |
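When the config keys cannot be passed straight to the constructor, `from_config` has to be overridden. The following is a minimal sketch of one such case, not something taken from the Keras sources: `ScaledInitializer` is a hypothetical class that wraps another initializer, stores it in the config with `tf.keras.initializers.serialize`, and rebuilds it with `tf.keras.initializers.deserialize`.

```python
import tensorflow as tf


class ScaledInitializer(tf.keras.initializers.Initializer):
    """Hypothetical example: scales the output of another initializer."""

    def __init__(self, base, scale):
        self.base = base    # an Initializer instance, not a plain Python value
        self.scale = scale

    def __call__(self, shape, dtype=None, **kwargs):
        return self.scale * self.base(shape, dtype=dtype)

    def get_config(self):
        # Store the wrapped initializer as a dict so the config stays serializable.
        return {
            "base": tf.keras.initializers.serialize(self.base),
            "scale": self.scale,
        }

    @classmethod
    def from_config(cls, config):
        # Copy first so the caller's dictionary is left untouched.
        config = dict(config)
        config["base"] = tf.keras.initializers.deserialize(config["base"])
        return cls(**config)


original = ScaledInitializer(
    base=tf.keras.initializers.RandomUniform(minval=-1.0, maxval=1.0),
    scale=0.5)
restored = ScaledInitializer.from_config(original.get_config())
```

The override is needed here because the `"base"` entry in the config is a dictionary, while the constructor expects an initializer object.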
get_config
get_config()
Returns the configuration of the initializer as a JSON-serializable dict.
Returns | |
---|---|
A JSON-serializable Python dict. |
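Because the returned dict is JSON-serializable, it can be written out as text and read back later. The sketch below illustrates this with the built-in `RandomUniform` initializer and the standard `json` module; the bounds are arbitrary example values.

```python
import json

import tensorflow as tf

initializer = tf.keras.initializers.RandomUniform(minval=-1.0, maxval=1.0)

# The config holds only plain Python values, so json can encode it directly.
config_text = json.dumps(initializer.get_config())

# Decoding the text gives back a dict that from_config accepts.
restored = tf.keras.initializers.RandomUniform.from_config(
    json.loads(config_text))
```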
__call__
__call__(shape, dtype=None, **kwargs)
Returns a tensor object initialized as specified by the initializer.
Args | |
---|---|
shape | Shape of the tensor. |
dtype | Optional dtype of the tensor. |
**kwargs | Additional keyword arguments. |
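An initializer can be called directly, which is a convenient way to inspect what it produces. The sketch below uses the built-in `RandomUniform` initializer with an explicit `dtype`; the shapes and the `Dense` layer size are arbitrary choices for illustration.

```python
import tensorflow as tf

initializer = tf.keras.initializers.RandomUniform(minval=-1.0, maxval=1.0)

# Calling the initializer returns a tensor with the requested shape and dtype.
values = initializer(shape=(2, 3), dtype=tf.float32)

# The same object can be handed to a layer, which uses it to create its weights.
layer = tf.keras.layers.Dense(4, kernel_initializer=initializer)
layer.build(input_shape=(None, 3))
```

Here `values` is a 2x3 `float32` tensor with entries between -1 and 1, and `layer.kernel` is a 3x4 weight matrix filled by the same initializer.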
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/keras/initializers/Initializer