W3cubDocs

/TensorFlow Python

tf.contrib.distributions.matrix_diag_transform(matrix, transform=None, name=None)

tf.contrib.distributions.matrix_diag_transform(matrix, transform=None, name=None)

See the guide: Statistical Distributions (contrib) > Multivariate distributions

Transform diagonal of [batch-]matrix, leave rest of matrix unchanged.

Create a trainable covariance defined by a Cholesky factor:

# Transform network layer into 2 x 2 array.
matrix_values = tf.contrib.layers.fully_connected(activations, 4)
matrix = tf.reshape(matrix_values, (batch_size, 2, 2))

# Make the diagonal positive.  If the upper triangle was zero, this would be a
# valid Cholesky factor.
chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)

# OperatorPDCholesky ignores the upper triangle.
operator = OperatorPDCholesky(chol)

Example of heteroskedastic 2-D linear regression.

# Get a trainable Cholesky factor.
matrix_values = tf.contrib.layers.fully_connected(activations, 4)
matrix = tf.reshape(matrix_values, (batch_size, 2, 2))
chol = matrix_diag_transform(matrix, transform=tf.nn.softplus)

# Get a trainable mean.
mu = tf.contrib.layers.fully_connected(activations, 2)

# This is a fully trainable multivariate normal!
dist = tf.contrib.distributions.MVNCholesky(mu, chol)

# Standard log loss.  Minimizing this will "train" mu and chol, and then dist
# will be a distribution predicting labels as multivariate Gaussians.
loss = -1 * tf.reduce_mean(dist.log_pdf(labels))

Args:

  • matrix: Rank R Tensor, R >= 2, where the last two dimensions are equal.
  • transform: Element-wise function mapping Tensors to Tensors. To be applied to the diagonal of matrix. If None, matrix is returned unchanged. Defaults to None.
  • name: A name to give created ops. Defaults to "matrix_diag_transform".

Returns:

A Tensor with same shape and dtype as matrix.

Defined in tensorflow/contrib/distributions/python/ops/distribution_util.py.

© 2017 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/distributions/matrix_diag_transform