
tf.contrib.kfac.fisher_blocks.ConvKFCBasicFB

Class ConvKFCBasicFB

Inherits From: KroneckerProductFB

Defined in tensorflow/contrib/kfac/python/ops/fisher_blocks.py.

FisherBlock for convolutional layers using the basic KFC approximation.

Estimates the Fisher Information matrix's block for a convolutional layer.

Consider a convolutional layer in this model with (unshared) filter matrix 'w'. For a minibatch that produces inputs 'a' and output preactivations 's', this FisherBlock estimates:

$$F(w) = \#\text{locations} \cdot \left( \mathrm{E}[\mathrm{flat}(a)\,\mathrm{flat}(a)^T] \otimes \mathrm{E}[\mathrm{flat}(ds)\,\mathrm{flat}(ds)^T] \right)$$

where

$$ds = \frac{d}{ds} \log p(y \mid x, w)$$

and #locations is the number of (x, y) locations where 'w' is applied. The expectations are taken over all examples and locations, and flat() concatenates an array's leading dimensions.

See equation 23 in https://arxiv.org/abs/1602.01407 for details.
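As a minimal pure-Python sketch (not the TensorFlow implementation), the estimate above can be computed from explicit lists of patch vectors flat(a) and output-gradient vectors flat(ds), one pair per (example, location). The helper names mean_outer, kron and kfc_basic_fisher are hypothetical, and patch extraction is assumed to have happened already.

```python
# Pure-Python sketch of F(w) = #locations * kron(E[a a^T], E[ds ds^T]).
# Assumes flat(a) and flat(ds) are already extracted, one pair per
# (example, location); this is an illustration, not the TensorFlow code.


def mean_outer(vectors):
    """Average of v v^T over a list of equal-length vectors."""
    n, dim = len(vectors), len(vectors[0])
    out = [[0.0] * dim for _ in range(dim)]
    for v in vectors:
        for i in range(dim):
            for j in range(dim):
                out[i][j] += v[i] * v[j] / n
    return out


def kron(a, b):
    """Kronecker product of two dense matrices stored as lists of rows."""
    return [
        [a[i][k] * b[j][l] for k in range(len(a[0])) for l in range(len(b[0]))]
        for i in range(len(a))
        for j in range(len(b))
    ]


def kfc_basic_fisher(patches, output_grads, num_locations):
    """Scale the Kronecker product of the two averaged outer products."""
    input_factor = mean_outer(patches)        # E[flat(a) flat(a)^T]
    output_factor = mean_outer(output_grads)  # E[flat(ds) flat(ds)^T]
    return [[num_locations * x for x in row]
            for row in kron(input_factor, output_factor)]


# One example whose filter is applied at two locations:
# 2-dimensional patches and a single output channel.
patches = [[1.0, 0.0], [0.0, 2.0]]
output_grads = [[1.0], [3.0]]
fisher = kfc_basic_fisher(patches, output_grads, num_locations=2)
print(fisher)  # [[5.0, 0.0], [0.0, 20.0]]
```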

Properties

num_registered_towers

Methods

__init__

__init__(
    layer_collection,
    params,
    padding,
    strides=None,
    dilation_rate=None,
    data_format=None,
    extract_patches_fn=None
)

Creates a ConvKFCBasicFB block.

Args:

  • layer_collection: The collection of all layers in the K-FAC approximate Fisher information matrix to which this FisherBlock belongs.
  • params: The parameters (Tensor or tuple of Tensors) of this layer. If kernel alone, a Tensor of shape [..spatial_filter_shape.., in_channels, out_channels]. If kernel and bias, a tuple of 2 elements: the kernel Tensor described above and a bias Tensor of shape [out_channels].
  • padding: str. Padding method, e.g. "SAME" or "VALID".
  • strides: List of ints or None. Contains [..spatial_filter_strides..] if 'extract_patches_fn' is compatible with tf.nn.convolution(), else [1, ..spatial_filter_strides.., 1].
  • dilation_rate: List of ints or None. Rate for dilation along each spatial dimension if 'extract_patches_fn' is compatible with tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
  • data_format: str or None. Format of input data.
  • extract_patches_fn: str or None. Name of function that extracts image patches. One of "extract_convolution_patches", "extract_image_patches", "extract_pointwise_conv2d_patches".
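The strides and dilation_rate Args use two layouts depending on extract_patches_fn. The hypothetical helper below (not part of the API) sketches that convention, assuming the caller already knows whether the chosen patch extractor is tf.nn.convolution()-compatible and that the data uses an NHWC-style layout; it only rearranges the list and performs no validation.

```python
def format_spatial_args(spatial_values, convolution_compatible):
    """Hypothetical helper: lay out per-dimension strides or dilation rates
    in the form the `strides` / `dilation_rate` Args describe.

    convolution_compatible: True when extract_patches_fn is compatible with
    tf.nn.convolution(); this flag is an assumption of the sketch.
    """
    values = [int(v) for v in spatial_values]
    if convolution_compatible:
        # [..spatial_filter_strides..]: spatial dimensions only.
        return values
    # [1, ..spatial_filter_strides.., 1]: a leading 1 for the batch dimension
    # and a trailing 1 for the channel dimension (assuming NHWC layout).
    return [1] + values + [1]


# A 2-D convolution with stride 2 in both spatial dimensions.
print(format_spatial_args([2, 2], convolution_compatible=True))   # [2, 2]
print(format_spatial_args([2, 2], convolution_compatible=False))  # [1, 2, 2, 1]
```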

full_fisher_block

full_fisher_block()

Explicitly constructs the full Fisher block.

Used for testing purposes. (In general, the result may be very large.)

Returns:

The full Fisher block.
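To see why the full block "may be very large", the size arithmetic can be sketched as below. This assumes the kernel is flattened to [prod(spatial_filter_shape) * in_channels, out_channels] and that a bias, when present, is handled by appending a constant 1 to each patch (a common K-FAC convention, assumed here rather than confirmed by this page). The function names are hypothetical.

```python
from math import prod


def factor_dims(kernel_shape, has_bias):
    """Sizes of the input and output Kronecker factors.

    Assumes the kernel [..spatial_filter_shape.., in_channels, out_channels]
    is flattened to [prod(spatial) * in_channels, out_channels] and that a
    bias adds one extra (constant 1) input coordinate.
    """
    *spatial, in_channels, out_channels = kernel_shape
    input_dim = prod(spatial) * in_channels + (1 if has_bias else 0)
    return input_dim, out_channels


def full_block_entries(kernel_shape, has_bias):
    """Entries in the explicit (input_dim * output_dim)^2 Fisher block."""
    input_dim, output_dim = factor_dims(kernel_shape, has_bias)
    return (input_dim * output_dim) ** 2


def factor_entries(kernel_shape, has_bias):
    """Entries in the two Kronecker factors when stored separately."""
    input_dim, output_dim = factor_dims(kernel_shape, has_bias)
    return input_dim ** 2 + output_dim ** 2


kernel_shape = [3, 3, 64, 128]
print(factor_dims(kernel_shape, has_bias=True))         # (577, 128)
print(full_block_entries(kernel_shape, has_bias=True))  # 5454708736
print(factor_entries(kernel_shape, has_bias=True))      # 349313
```

Under these assumptions, a 3x3 kernel with 64 input channels, 128 output channels and a bias yields a full block with over five billion entries, while the two factors together hold about 350 thousand.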

instantiate_factors

instantiate_factors(
    grads_list,
    damping
)

Creates and registers the component factors of this Fisher block.

Args:

  • grads_list: A list of gradients (each a Tensor or tuple of Tensors) with respect to the tensors returned by tensors_to_compute_grads() that are to be used to estimate the block.
  • damping: The damping factor (float or Tensor).

multiply

multiply(vector)

Multiplies the vector by the (damped) block.

Args:

  • vector: The vector (a Tensor or tuple of Tensors) to be multiplied.

Returns:

The vector left-multiplied by the (damped) block.
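Because the block is a Kronecker product, a multiply does not need the full matrix. The sketch below uses the identity kron(A, G) vec(V) = vec(A V G^T) for row-major vec, where V is the vector reshaped to [input_dim, output_dim]. It is a hypothetical pure-Python illustration with damping omitted, and it checks the factored result against the explicit Kronecker product.

```python
def matmul(x, y):
    """Dense matrix product of matrices stored as lists of rows."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y)))
             for j in range(len(y[0]))]
            for i in range(len(x))]


def transpose(x):
    return [list(col) for col in zip(*x)]


def kron(a, b):
    """Explicit Kronecker product, used here only to check the result."""
    return [[a[i][k] * b[j][l] for k in range(len(a[0])) for l in range(len(b[0]))]
            for i in range(len(a)) for j in range(len(b))]


def kron_multiply(input_factor, output_factor, v):
    """kron(A, G) vec(V) computed as A V G^T, never forming kron(A, G).

    v is the vector reshaped to [input_dim, output_dim]; vec() is row-major.
    """
    return matmul(matmul(input_factor, v), transpose(output_factor))


A = [[2, 1], [1, 3]]  # input factor
G = [[1, 0], [0, 4]]  # output factor
V = [[1, 2], [3, 4]]  # vector reshaped to [input_dim, output_dim]

factored = kron_multiply(A, G, V)
flat_v = [x for row in V for x in row]
explicit = [sum(f * x for f, x in zip(row, flat_v)) for row in kron(A, G)]
print(factored)                                          # [[5, 32], [10, 56]]
print([x for row in factored for x in row] == explicit)  # True
```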

multiply_inverse

multiply_inverse(vector)

Multiplies the vector by the (damped) inverse of the block.

Args:

  • vector: The vector (a Tensor or tuple of Tensors) to be multiplied.

Returns:

The vector left-multiplied by the (damped) inverse of the block.
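The inverse of a Kronecker product is the Kronecker product of the inverses, so an inverse multiply can also be applied factor by factor: inverse(kron(A, G)) vec(V) = vec(inverse(A) V inverse(G)^T). In the sketch below, damping is modeled by adding the same value to the diagonal of each factor; that is a simplifying assumption, not necessarily how this block splits damping between its factors. All names are hypothetical.

```python
def matmul(x, y):
    """Dense matrix product of matrices stored as lists of rows."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y)))
             for j in range(len(y[0]))]
            for i in range(len(x))]


def transpose(x):
    return [list(col) for col in zip(*x)]


def inverse(m):
    """Gauss-Jordan inverse with partial pivoting (small dense matrices)."""
    n = len(m)
    aug = [[float(x) for x in row] + [1.0 if i == j else 0.0 for j in range(n)]
           for i, row in enumerate(m)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        scale = aug[col][col]
        aug[col] = [x / scale for x in aug[col]]
        for r in range(n):
            if r != col:
                factor = aug[r][col]
                aug[r] = [x - factor * y for x, y in zip(aug[r], aug[col])]
    return [row[n:] for row in aug]


def add_damping(m, damping):
    """Return m + damping * I (assumed per-factor damping)."""
    return [[x + (damping if i == j else 0.0) for j, x in enumerate(row)]
            for i, row in enumerate(m)]


def damped_multiply(a, g, v, damping):
    a_d, g_d = add_damping(a, damping), add_damping(g, damping)
    return matmul(matmul(a_d, v), transpose(g_d))


def damped_multiply_inverse(a, g, v, damping):
    """inverse(kron(A_d, G_d)) vec(V) = vec(inverse(A_d) V inverse(G_d)^T)."""
    a_d, g_d = add_damping(a, damping), add_damping(g, damping)
    return matmul(matmul(inverse(a_d), v), transpose(inverse(g_d)))


A = [[2.0, 1.0], [1.0, 3.0]]
G = [[1.0, 0.0], [0.0, 4.0]]
V = [[1.0, 2.0], [3.0, 4.0]]

x = damped_multiply_inverse(A, G, V, damping=0.5)
# Multiplying back by the same damped block recovers V (up to rounding).
print(damped_multiply(A, G, x, damping=0.5))
```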

multiply_matpower

multiply_matpower(
    vector,
    exp
)

Multiplies the vector by the (damped) matrix-power of the block.

Args:

  • vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
  • exp: A float representing the power to raise the block by before multiplying it by the vector.

Returns:

The vector left-multiplied by the (damped) matrix-power of the block.
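The mixed-product property gives kron(A, G)^p = kron(A^p, G^p), so a matrix power can also be applied per factor. To stay self-contained, the sketch below handles only non-negative integer exponents by repeated multiplication and omits damping; fractional or negative powers (for example exp=-0.5) would instead need an eigendecomposition of each symmetric factor. The helper names are hypothetical.

```python
def matmul(x, y):
    """Dense matrix product of matrices stored as lists of rows."""
    return [[sum(x[i][k] * y[k][j] for k in range(len(y)))
             for j in range(len(y[0]))]
            for i in range(len(x))]


def transpose(x):
    return [list(col) for col in zip(*x)]


def kron(a, b):
    return [[a[i][k] * b[j][l] for k in range(len(a[0])) for l in range(len(b[0]))]
            for i in range(len(a)) for j in range(len(b))]


def identity(n):
    return [[1 if i == j else 0 for j in range(n)] for i in range(n)]


def matpower(m, exp):
    """m ** exp for a non-negative integer exp, by repeated multiplication."""
    if not isinstance(exp, int) or exp < 0:
        raise ValueError("this sketch only supports non-negative integer exp")
    result = identity(len(m))
    for _ in range(exp):
        result = matmul(result, m)
    return result


def multiply_matpower_factored(a, g, v, exp):
    """kron(A, G)^exp vec(V) = vec(A^exp V (G^exp)^T), via the mixed-product property."""
    return matmul(matmul(matpower(a, exp), v), transpose(matpower(g, exp)))


A = [[2, 1], [1, 3]]
G = [[1, 0], [0, 4]]
V = [[1, 2], [3, 4]]

# Mixed-product property: kron(A, G)^2 equals kron(A^2, G^2).
print(matpower(kron(A, G), 2) == kron(matpower(A, 2), matpower(G, 2)))  # True

factored = multiply_matpower_factored(A, G, V, 2)
flat_v = [x for row in V for x in row]
explicit = [sum(f * x for f, x in zip(row, flat_v)) for row in matpower(kron(A, G), 2)]
print([x for row in factored for x in row] == explicit)  # True
```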

register_additional_tower

register_additional_tower(
    inputs,
    outputs
)

register_inverse

register_inverse()

Registers a matrix inverse to be computed by the block.

register_matpower

register_matpower(exp)

Registers a matrix power to be computed by the block.

Args:

  • exp: A float representing the power to raise the block by.

tensors_to_compute_grads

tensors_to_compute_grads()

The Tensors with respect to which the derivative of the loss should be computed.

© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/kfac/fisher_blocks/ConvKFCBasicFB