Source code for svb.posterior

"""
Definition of the posterior distribution
"""
try:
    import tensorflow.compat.v1 as tf
except ImportError:
    import tensorflow as tf

import numpy as np

from .utils import LogBase
from . import dist

def get_posterior(idx, param, t, data_model, **kwargs):
    """
    Factory method to return a posterior

    :param idx: Index of the parameter within the model
    :param param: svb.parameter.Parameter instance
    :param t: Timepoints, passed to the parameter's posterior initialization callable if defined
    :param data_model: Data model providing ``n_vertices`` and ``data_flattened``
    """
    nvertices = data_model.n_vertices
    initial_mean, initial_var = None, None
    if param.post_init is not None:
        initial_mean, initial_var = param.post_init(param, t, data_model.data_flattened)

    if initial_mean is None:
        initial_mean = tf.fill([nvertices], param.post_dist.mean)
    else:
        initial_mean = param.post_dist.transform.int_values(initial_mean)

    if initial_var is None:
        initial_var = tf.fill([nvertices], param.post_dist.var)
    else:
        # FIXME variance not value?
        initial_var = param.post_dist.transform.int_values(initial_var)

    if param.post_type == "vertexwise" and isinstance(param.post_dist, dist.Normal):
        return NormalPosterior(idx, initial_mean, initial_var, name=param.name, **kwargs)

    if param.post_type == "global" and isinstance(param.post_dist, dist.Normal):
        return GaussianGlobalPosterior(idx, initial_mean, initial_var, name=param.name, **kwargs)

    raise ValueError("Can't create %s posterior for distribution: %s"
                     % (param.post_type, param.post_dist))

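# Illustrative sketch (not part of the original source): how get_posterior()
# might be driven. The stand-in objects below are hypothetical and expose only
# the attributes that get_posterior() itself reads; real code would pass an
# svb.parameter.Parameter and a data model instance, and the dist.Normal
# (mean, var) constructor shown is an assumption.
#
#     from types import SimpleNamespace
#     import numpy as np
#
#     param = SimpleNamespace(
#         name="ftiss",
#         post_init=None,                     # no custom initialization function
#         post_dist=dist.Normal(1.0, 1e6),    # assumed (mean, var) signature
#         post_type="vertexwise",
#     )
#     data_model = SimpleNamespace(n_vertices=100,
#                                  data_flattened=np.zeros((100, 10), dtype=np.float32))
#     post = get_posterior(0, param, t=None, data_model=data_model)
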

class Posterior(LogBase):
    """
    Posterior distribution
    """
    def __init__(self, idx, **kwargs):
        LogBase.__init__(self, **kwargs)
        self._idx = idx

    def _get_mean_var(self, mean, var, init_post):
        if init_post is not None:
            mean, cov = init_post
            #if mean.shape[0] != self.nvertices:
            #    raise ValueError("Initializing posterior with %i vertices but input contains %i vertices" % (self.nvertices, mean.shape[0]))
            if self._idx >= mean.shape[1]:
                raise ValueError("Initializing posterior for parameter %i but input contains %i parameters" % (self._idx+1, mean.shape[1]))

            # We have been provided with an initialization posterior. Extract the mean and
            # diagonal of the covariance and use that as the initial values of the mean and
            # variance. Note that the covariance initialization is only used if this
            # parameter is embedded in an MVN
            mean = mean[:, self._idx]
            var = cov[:, self._idx, self._idx]
            self.log.info(" - Initializing posterior mean and variance from input posterior")
            self.log.info("   means=%s", np.mean(mean))
            self.log.info("   vars=%s", np.mean(var))
        return mean, var

    def sample(self, nsamples):
        """
        :param nsamples: Number of samples to return per parameter vertex / parameter

        :return: A tensor of shape [W, P, S] where W is the number of parameter vertices,
                 P is the number of parameters in the distribution (possibly 1) and
                 S is the number of samples
        """
        raise NotImplementedError()

    def entropy(self, samples=None):
        """
        :param samples: A tensor of shape [W, P, S] where W is the number of parameter vertices,
                        P is the number of parameters in the prior (possibly 1) and S is the
                        number of samples. This parameter may or may not be used in the
                        calculation. If it is required, the implementation class must check
                        that it is provided

        :return: Tensor of shape [W] containing the vertexwise distribution entropy
        """
        raise NotImplementedError()

    def state(self):
        """
        :return: Sequence of tf.Tensor objects containing the state of all variables in this
                 posterior. The tensors returned will be evaluated to create a savable state
                 which may then be passed back into set_state()
        """
        raise NotImplementedError()

    def set_state(self, state):
        """
        :param state: State of variables in this posterior, as returned by a previous call
                      to state()

        :return: Sequence of tf.Operation objects which will set the variables in this
                 posterior to the specified state
        """
        raise NotImplementedError()

    def log_det_cov(self):
        raise NotImplementedError()


class NormalPosterior(Posterior):
    """
    Posterior distribution for a single vertexwise parameter with a normal distribution
    """

    def __init__(self, idx, mean, var, **kwargs):
        """
        :param mean: Tensor of shape [W] containing the mean at each parameter vertex
        :param var: Tensor of shape [W] containing the variance at each parameter vertex
        """
        Posterior.__init__(self, idx, **kwargs)
        self.nvertices = tf.shape(mean)[0]
        self.name = kwargs.get("name", "NormPost")

        mean, var = self._get_mean_var(mean, var, kwargs.get("init", None))
        mean = tf.cast(mean, tf.float32)
        var = tf.cast(var, tf.float32)
        mean = self.log_tf(tf.where(tf.is_finite(mean), mean, tf.zeros_like(mean)))
        var = tf.where(tf.is_nan(var), tf.ones_like(var), var)

        self.mean_variable = self.log_tf(tf.Variable(mean, validate_shape=False,
                                                     name="%s_mean" % self.name))
        self.log_var = self.log_tf(tf.Variable(tf.log(var), validate_shape=False,
                                               name="%s_log_var" % self.name))
        self.var_variable = self.log_tf(tf.exp(self.log_var, name="%s_var" % self.name))

        if kwargs.get("suppress_nan", True):
            #self.mean = tf.where(tf.is_nan(self.mean_variable), tf.ones_like(self.mean_variable), self.mean_variable)
            #self.var = tf.where(tf.is_nan(self.var_variable), tf.ones_like(self.var_variable), self.var_variable)
            self.mean = tf.where(tf.is_nan(self.mean_variable), mean, self.mean_variable)
            self.var = tf.where(tf.is_nan(self.var_variable), var, self.var_variable)
        else:
            self.mean = self.mean_variable
            self.var = self.var_variable
        self.std = self.log_tf(tf.sqrt(self.var, name="%s_std" % self.name))

    def sample(self, nsamples):
        eps = tf.random_normal((self.nvertices, 1, nsamples), 0, 1, dtype=tf.float32)
        tiled_mean = tf.tile(tf.reshape(self.mean, [self.nvertices, 1, 1]), [1, 1, nsamples])
        sample = self.log_tf(tf.add(tiled_mean,
                                    tf.multiply(tf.reshape(self.std, [self.nvertices, 1, 1]), eps),
                                    name="%s_sample" % self.name))
        return sample

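    # Illustrative note (not part of the original source): sample() uses the
    # "reparameterization trick": a draw from N(mean, var) is generated as
    #
    #     sample = mean + sqrt(var) * eps,    eps ~ N(0, 1)
    #
    # so gradients flow through mean and std while the randomness lives in eps.
    # A minimal NumPy sketch of the same calculation, assuming W vertices and
    # S samples:
    #
    #     eps = np.random.normal(size=(W, 1, S))
    #     samples = mean.reshape(W, 1, 1) + std.reshape(W, 1, 1) * eps
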
    def entropy(self, _samples=None):
        entropy = tf.identity(-0.5 * tf.log(self.var), name="%s_entropy" % self.name)
        return self.log_tf(entropy)

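    # Illustrative note (not part of the original source): the differential
    # entropy of N(mean, var) is 0.5 * log(2 * pi * e * var). The value
    # returned above is -0.5 * log(var), i.e. the negative entropy up to the
    # additive constant 0.5 * log(2 * pi * e), which does not affect
    # optimization.
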
    def state(self):
        return [self.mean, self.log_var]

    def set_state(self, state):
        return [
            tf.assign(self.mean_variable, state[0]),
            tf.assign(self.log_var, state[1])
        ]

    def __str__(self):
        return "Vertexwise posterior"


class GaussianGlobalPosterior(Posterior):
    """
    Posterior which has the same value at every parameter vertex
    """

    def __init__(self, idx, mean, var, **kwargs):
        """
        :param mean: Tensor of shape [W] containing the mean at each parameter vertex
        :param var: Tensor of shape [W] containing the variance at each parameter vertex
        """
        Posterior.__init__(self, idx, **kwargs)
        self.nvertices = tf.shape(mean)[0]
        self.name = kwargs.get("name", "GlobalPost")

        mean, var = self._get_mean_var(mean, var, kwargs.get("init", None))

        # Take the mean of the mean and variance across vertices as the initial value
        # in case there is a vertexwise initialization function
        initial_mean_global = tf.reshape(tf.reduce_mean(mean), [1])
        initial_var_global = tf.reshape(tf.reduce_mean(var), [1])
        self.mean_variable = tf.Variable(initial_mean_global, dtype=tf.float32,
                                         validate_shape=False, name="%s_mean" % self.name)
        self.log_var = tf.Variable(tf.log(tf.cast(initial_var_global, dtype=tf.float32)),
                                   validate_shape=False, name="%s_log_var" % self.name)
        self.var_variable = self.log_tf(tf.exp(self.log_var, name="%s_var" % self.name))

        if kwargs.get("suppress_nan", True):
            self.mean_global = tf.where(tf.is_nan(self.mean_variable), initial_mean_global, self.mean_variable)
            self.var_global = tf.where(tf.is_nan(self.var_variable), initial_var_global, self.var_variable)
        else:
            self.mean_global = self.mean_variable
            self.var_global = self.var_variable

        self.mean = self.log_tf(tf.tile(self.mean_global, [self.nvertices]), name="%s_meang" % self.name)
        self.var = tf.tile(self.var_global, [self.nvertices])
        self.std = self.log_tf(tf.sqrt(self.var, name="%s_std" % self.name))

    def sample(self, nsamples):
        """
        FIXME should each parameter vertex get the same sample? Currently YES
        """
        eps = tf.random_normal((1, 1, nsamples), 0, 1, dtype=tf.float32)
        tiled_mean = tf.tile(tf.reshape(self.mean, [self.nvertices, 1, 1]), [1, 1, nsamples])
        sample = self.log_tf(tf.add(tiled_mean,
                                    tf.multiply(tf.reshape(self.std, [self.nvertices, 1, 1]), eps),
                                    name="%s_sample" % self.name))
        return sample

    def entropy(self, _samples=None):
        entropy = tf.identity(-0.5 * tf.log(self.var), name="%s_entropy" % self.name)
        return self.log_tf(entropy)

    def state(self):
        return [self.mean_global, self.log_var]

    def set_state(self, state):
        return [
            tf.assign(self.mean_variable, state[0]),
            tf.assign(self.log_var, state[1])
        ]

    def __str__(self):
        return "Global posterior"


class FactorisedPosterior(Posterior):
    """
    Posterior distribution for a set of parameters with no covariance
    """

    def __init__(self, posts, **kwargs):
        Posterior.__init__(self, -1, **kwargs)
        self.posts = posts
        self.nparams = len(self.posts)
        self.name = kwargs.get("name", "FactPost")

        means = [post.mean for post in self.posts]
        variances = [post.var for post in self.posts]
        mean = tf.stack(means, axis=-1, name="%s_mean" % self.name)
        var = tf.stack(variances, axis=-1, name="%s_var" % self.name)
        self.mean = self.log_tf(tf.identity(mean, name="%s_mean" % self.name))
        self.var = self.log_tf(tf.identity(var, name="%s_var" % self.name))
        self.std = tf.sqrt(self.var, name="%s_std" % self.name)
        self.nvertices = posts[0].nvertices

        # Covariance matrix is diagonal
        self.cov = tf.matrix_diag(self.var, name='%s_cov' % self.name)

        # Regularisation to make sure cov is invertible. Note that we do not
        # need this for a diagonal covariance matrix but it is useful for
        # the full MVN covariance which shares some of the calculations
        self.cov_reg = 1e-5*tf.eye(self.nparams)

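    # Illustrative note (not part of the original source): each component
    # posterior supplies a [W] mean and variance, so after stacking self.mean
    # and self.var have shape [W, P] and self.cov has shape [W, P, P] (a
    # diagonal matrix at each vertex).
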
    def sample(self, nsamples):
        samples = [post.sample(nsamples) for post in self.posts]
        sample = tf.concat(samples, axis=1, name="%s_sample" % self.name)
        return self.log_tf(sample)

    def entropy(self, _samples=None):
        entropy = tf.zeros([self.nvertices], dtype=tf.float32)
        for post in self.posts:
            entropy = tf.add(entropy, post.entropy(), name="%s_entropy" % self.name)
        return self.log_tf(entropy)

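    # Illustrative note (not part of the original source): for a factorised
    # (independent) distribution the joint entropy is the sum of the
    # per-parameter entropies, which is what the loop above accumulates
    # vertexwise.
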
    def state(self):
        state = []
        for post in self.posts:
            state.extend(post.state())
        return state

    def set_state(self, state):
        ops = []
        for idx, post in enumerate(self.posts):
            ops += post.set_state(state[idx*2:idx*2+2])
        return ops

    def log_det_cov(self):
        """
        Determinant of a diagonal matrix is the product of the diagonal entries
        """
        return tf.reduce_sum(tf.log(self.var), axis=1, name='%s_log_det_cov' % self.name)

    def latent_loss(self, prior):
        """
        Analytic expression for the latent loss which can be used when the posterior and
        prior are both Gaussian

        https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence

        :param prior: Vertexwise Prior instance which defines the ``mean`` and ``cov``
                      vertexwise attributes
        """
        prior_cov_inv = tf.matrix_inverse(prior.cov)
        mean_diff = tf.subtract(self.mean, prior.mean)

        term1 = tf.trace(tf.matmul(prior_cov_inv, self.cov))
        term2 = tf.matmul(tf.reshape(mean_diff, (self.nvertices, 1, -1)), prior_cov_inv)
        term3 = tf.reshape(tf.matmul(term2, tf.reshape(mean_diff, (self.nvertices, -1, 1))), [self.nvertices])
        term4 = prior.log_det_cov()
        term5 = self.log_det_cov()

        return self.log_tf(tf.identity(0.5*(term1 + term3 - self.nparams + term4 - term5),
                                       name="%s_latent_loss" % self.name))

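    # Illustrative note (not part of the original source): latent_loss()
    # evaluates the analytic KL divergence between two multivariate normals,
    #
    #     KL(q || p) = 0.5 * [ tr(Sigma_p^-1 Sigma_q)
    #                          + (mu_q - mu_p)^T Sigma_p^-1 (mu_q - mu_p)
    #                          - P + log|Sigma_p| - log|Sigma_q| ]
    #
    # where q is this posterior, p is the prior and P is the number of
    # parameters. term1, term3, term4 and term5 correspond to the trace,
    # quadratic, prior log-determinant and posterior log-determinant terms
    # respectively, evaluated independently at each vertex.
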

class MVNPosterior(FactorisedPosterior):
    """
    Multivariate Normal posterior distribution
    """

    def __init__(self, posts, **kwargs):
        FactorisedPosterior.__init__(self, posts, **kwargs)

        # The full covariance matrix is formed from the Cholesky decomposition
        # to ensure that it remains positive definite.
        #
        # To achieve this, we have to create PxP tensor variables for
        # each parameter vertex, but we then extract only the lower triangular
        # elements and train only on these. The diagonal elements
        # are constructed by the FactorisedPosterior
        if kwargs.get("init", None):
            # We are initializing from an existing posterior.
            # The FactorisedPosterior will already have extracted the mean and
            # diagonal of the covariance matrix - we need the Cholesky decomposition
            # of the covariance to initialize the off-diagonal terms
            self.log.info(" - Initializing posterior covariance from input posterior")
            _mean, cov = kwargs["init"]
            covar_init = tf.cholesky(cov)
        else:
            covar_init = tf.zeros([self.nvertices, self.nparams, self.nparams], dtype=tf.float32)

        self.off_diag_vars_base = self.log_tf(tf.Variable(covar_init, validate_shape=False,
                                                          name='%s_off_diag_vars' % self.name))
        if kwargs.get("suppress_nan", True):
            self.off_diag_vars = tf.where(tf.is_nan(self.off_diag_vars_base),
                                          tf.zeros_like(self.off_diag_vars_base),
                                          self.off_diag_vars_base)
        else:
            self.off_diag_vars = self.off_diag_vars_base
        self.off_diag_cov_chol = tf.matrix_set_diag(tf.matrix_band_part(self.off_diag_vars, -1, 0),
                                                    tf.zeros([self.nvertices, self.nparams]),
                                                    name='%s_off_diag_cov_chol' % self.name)

        # Combine diagonal and off-diagonal elements into the full matrix
        self.cov_chol = tf.add(tf.matrix_diag(self.std), self.off_diag_cov_chol,
                               name='%s_cov_chol' % self.name)

        # Form the covariance matrix from the Cholesky decomposition
        self.cov = tf.matmul(tf.transpose(self.cov_chol, perm=(0, 2, 1)), self.cov_chol,
                             name='%s_cov' % self.name)

        self.cov_chol = self.log_tf(self.cov_chol)
        self.cov = self.log_tf(self.cov)

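    # Illustrative sketch (not part of the original source): for a single
    # vertex with P parameters, and assuming std (shape [P]) and
    # off_diag_vars (shape [P, P]) are NumPy arrays, the construction above
    # does roughly:
    #
    #     chol = np.diag(std) + np.tril(off_diag_vars, k=-1)   # triangular factor
    #     cov = chol.T @ chol                                   # symmetric, positive semi-definite
    #
    # Training operates on the unconstrained off-diagonal variables and on the
    # log-variances, so the covariance built this way stays valid throughout
    # optimization.
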
    def log_det_cov(self):
        """
        Determinant of a matrix can be calculated from the Cholesky decomposition which may
        be faster and more stable than tf.matrix_determinant
        """
        return self.log_tf(tf.multiply(2.0,
                                       tf.reduce_sum(tf.log(tf.matrix_diag_part(self.cov_chol)), axis=1),
                                       name="%s_det_cov" % self.name))

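    # Illustrative note (not part of the original source): if C = L^T L as
    # constructed in __init__, then det(C) = det(L)^2 and, since L is
    # triangular, det(L) is the product of its diagonal entries. Hence
    #
    #     log det(C) = 2 * sum_i log(L_ii)
    #
    # which is what log_det_cov() computes above.
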
    def sample(self, nsamples):
        # Use the 'reparameterization trick' to return the samples
        eps = tf.random_normal((self.nvertices, self.nparams, nsamples), 0, 1,
                               dtype=tf.float32, name="eps")

        # NB self.cov_chol is the Cholesky decomposition of the covariance matrix
        # so plays the role of the std.dev.
        tiled_mean = tf.tile(tf.reshape(self.mean, [self.nvertices, self.nparams, 1]),
                             [1, 1, nsamples])
        sample = tf.add(tiled_mean, tf.matmul(self.cov_chol, eps), name="%s_sample" % self.name)
        return self.log_tf(sample)

    def entropy(self, _samples=None):
        entropy = tf.identity(-0.5 * self.log_det_cov(), name="%s_entropy" % self.name)
        return self.log_tf(entropy)

    def state(self):
        return list(FactorisedPosterior.state(self)) + [self.off_diag_vars]

    def set_state(self, state):
        ops = list(FactorisedPosterior.set_state(self, state[:-1]))
        ops += [tf.assign(self.off_diag_vars_base, state[-1], validate_shape=False)]
        return ops