Source code for svb.prior

"""
Definition of prior distribution
"""
import numpy as np

try:
    import tensorflow.compat.v1 as tf
except ImportError:
    import tensorflow as tf

from .utils import LogBase
from .dist import Normal

PRIOR_TYPE_NONSPATIAL = "N"
PRIOR_TYPE_SPATIAL_MRF = "M"

def get_prior(param, data_model, **kwargs):
    """
    Factory method to return a vertexwise prior
    """
    prior = None
    if isinstance(param.prior_dist, Normal):
        if param.prior_type == PRIOR_TYPE_NONSPATIAL:
            prior = NormalPrior(data_model.n_vertices, param.prior_dist.mean, param.prior_dist.var, **kwargs)
        elif param.prior_type == PRIOR_TYPE_SPATIAL_MRF:
            prior = MRFSpatialPrior(data_model.n_vertices, param.prior_dist.mean, param.prior_dist.var, **kwargs)
        elif param.prior_type == "M2":
            prior = MRF2SpatialPrior(data_model.n_vertices, param.prior_dist.mean, param.prior_dist.var, **kwargs)
        elif param.prior_type == "Mfab":
            prior = FabberMRFSpatialPrior(data_model.n_vertices, param.prior_dist.mean, param.prior_dist.var, **kwargs)
        elif param.prior_type == "A":
            prior = ARDPrior(data_model.n_vertices, param.prior_dist.mean, param.prior_dist.var, **kwargs)

    if prior is not None:
        return prior
    else:
        raise ValueError("Can't create prior type %s for distribution %s - unrecognized combination"
                         % (param.prior_type, param.prior_dist))
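# A minimal usage sketch for the factory. The objects below are hypothetical
# stand-ins (the real Parameter and DataModel classes live elsewhere in svb,
# and the Normal constructor is assumed to take a mean and variance):
#
#     from svb.dist import Normal
#
#     class FakeParam:
#         prior_dist = Normal(0.0, 1.0)       # assumed constructor signature
#         prior_type = PRIOR_TYPE_NONSPATIAL
#
#     class FakeDataModel:
#         n_vertices = 1000
#
#     prior = get_prior(FakeParam(), FakeDataModel())
#     print(prior)  # -> "Non-spatial prior (0.000000, 1.000000)"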
class Prior(LogBase):
    """
    Base class for a prior, defining methods that must be implemented
    """
    def mean_log_pdf(self, samples):
        """
        :param samples: A tensor of shape [W, P, S] where W is the number
                        of parameter vertices, P is the number of parameters in the
                        prior (possibly 1) and S is the number of samples

        :return: A tensor of shape [W] where W is the number of parameter vertices
                 containing the mean log PDF of the parameter samples provided
        """
        raise NotImplementedError()
    def log_det_cov(self):
        raise NotImplementedError()
class NormalPrior(Prior):
    """
    Prior based on a vertexwise univariate normal distribution
    """

    def __init__(self, nvertices, mean, var, **kwargs):
        """
        :param mean: Prior mean value
        :param var: Prior variance
        """
        Prior.__init__(self)
        self.name = kwargs.get("name", "NormalPrior")
        self.nvertices = nvertices
        self.scalar_mean = mean
        self.scalar_var = var
        self.mean = tf.fill([nvertices], mean, name="%s_mean" % self.name)
        self.var = tf.fill([nvertices], var, name="%s_var" % self.name)
        self.std = tf.sqrt(self.var, name="%s_std" % self.name)
    def mean_log_pdf(self, samples):
        """
        Mean log PDF for normal distribution

        Note that ``term1`` is a constant offset when the prior variance is fixed
        and hence in earlier versions of the code this was neglected, along with
        other constant offsets such as factors of pi. However when this code is
        inherited by spatial priors and ARD the variance is no longer fixed and
        this term must be included.
        """
        dx = tf.subtract(samples, tf.reshape(self.mean, [self.nvertices, 1, 1]))  # [W, 1, N]
        z = tf.div(tf.square(dx), tf.reshape(self.var, [self.nvertices, 1, 1]))  # [W, 1, N]
        term1 = self.log_tf(-0.5 * tf.log(tf.reshape(self.var, [self.nvertices, 1, 1])), name="term1")
        term2 = self.log_tf(-0.5 * z, name="term2")
        log_pdf = term1 + term2  # [W, 1, N]
        mean_log_pdf = tf.reshape(tf.reduce_mean(log_pdf, axis=-1), [self.nvertices])  # [W]
        return mean_log_pdf
    def __str__(self):
        return "Non-spatial prior (%f, %f)" % (self.scalar_mean, self.scalar_var)
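# Quick numpy check of the log PDF terms above (the implementation drops the
# constant -0.5*log(2*pi), so it differs from the full normal log PDF by
# exactly that offset):
#
#     import numpy as np
#     from scipy.stats import norm
#
#     mean, var = 1.0, 2.0
#     x = np.random.randn(5)
#     ours = -0.5 * np.log(var) - 0.5 * (x - mean) ** 2 / var  # term1 + term2
#     ref = norm.logpdf(x, loc=mean, scale=np.sqrt(var)) + 0.5 * np.log(2 * np.pi)
#     assert np.allclose(ours, ref)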
class FabberMRFSpatialPrior(NormalPrior):
    """
    Prior designed to mimic the 'M' type spatial prior in Fabber.

    Note that this uses update equations for ak which is not in the spirit of the
    stochastic method. 'Native' SVB MRF spatial priors are also defined which simply
    treat the spatial precision parameter as an inference variable.

    This code has been verified to generate the same ak estimate given the same input
    as Fabber, however in practice it does not optimize to the same value. We don't
    yet know why.
    """

    def __init__(self, nvertices, mean, var, idx=None, post=None, nn=None, n2=None, **kwargs):
        """
        :param mean: Tensor of shape [W] containing the prior mean at each parameter vertex
        :param var: Tensor of shape [W] containing the prior variance at each parameter vertex
        :param post: Posterior instance
        :param nn: Sparse tensor of shape [W, W] containing nearest neighbour lists
        :param n2: Sparse tensor of shape [W, W] containing second nearest neighbour lists
        """
        NormalPrior.__init__(self, nvertices, mean, var, name="FabberMRFSpatialPrior")
        self.idx = idx

        # Save the original vertexwise mean and variance - the actual prior mean/var
        # will be calculated from these and also the spatial variation in neighbour vertices
        self.fixed_mean = self.mean
        self.fixed_var = self.var

        # nn and n2 are sparse tensors of shape [W, W]. If nn[A, B] = 1 then A is
        # a nearest neighbour of B, and similarly for n2 and second nearest neighbours
        self.nn = nn
        self.n2 = n2

        # Set up spatial smoothing parameter calculation from posterior and neighbour lists
        self._setup_ak(post, nn, n2)

        # Set up prior mean/variance
        self._setup_mean_var(post, nn, n2)

    def __str__(self):
        return "Spatial MRF prior (%f, %f)" % (self.scalar_mean, self.scalar_var)

    def _setup_ak(self, post, nn, n2):
        # This is the equivalent of CalculateAk in Fabber
        #
        # Some of this could probably be better done using linalg
        # operations but bear in mind this is one parameter only
        self.sigmaK = self.log_tf(tf.matrix_diag_part(post.cov)[:, self.idx], name="sigmak")  # [W]
        self.wK = self.log_tf(post.mean[:, self.idx], name="wk")  # [W]
        self.num_nn = self.log_tf(tf.sparse_reduce_sum(self.nn, axis=1), name="num_nn")  # [W]

        # Sum over vertices of parameter variance multiplied by number of
        # nearest neighbours for each vertex
        trace_term = self.log_tf(tf.reduce_sum(self.sigmaK * self.num_nn), name="trace")  # [1]

        # Sum of nearest and next-nearest neighbour mean values
        self.sum_means_nn = self.log_tf(tf.reshape(tf.sparse_tensor_dense_matmul(self.nn, tf.reshape(self.wK, (-1, 1))), (-1,)), name="wksum")  # [W]
        self.sum_means_n2 = self.log_tf(tf.reshape(tf.sparse_tensor_dense_matmul(self.n2, tf.reshape(self.wK, (-1, 1))), (-1,)), name="contrib8")  # [W]

        # vertex parameter mean multiplied by number of nearest neighbours
        wknn = self.log_tf(self.wK * self.num_nn, name="wknn")  # [W]

        swk = self.log_tf(wknn - self.sum_means_nn, name="swk")  # [W]
        term2 = self.log_tf(tf.reduce_sum(swk * self.wK), name="term2")  # [1]

        gk = 1 / (0.5 * trace_term + 0.5 * term2 + 0.1)
        hk = tf.multiply(tf.to_float(self.nvertices), 0.5) + 1.0
        self.ak = self.log_tf(tf.identity(gk * hk, name="ak"))

    def _setup_mean_var(self, post, nn, n2):
        # This is the equivalent of ApplyToMVN in Fabber
        contrib_nn = self.log_tf(8*self.sum_means_nn, name="contrib_nn")  # [W]
        contrib_n2 = self.log_tf(-self.sum_means_n2, name="contrib_n2")  # [W]

        spatial_mean = self.log_tf(contrib_nn / (8*self.num_nn), name="spatial_mean")
        spatial_prec = self.log_tf(self.num_nn * self.ak, name="spatial_prec")

        self.var = self.log_tf(1 / (1/self.fixed_var + spatial_prec), name="%s_var" % self.name)
        #self.var = self.fixed_var
        self.mean = self.log_tf(self.var * spatial_prec * spatial_mean, name="%s_mean" % self.name)
        #self.mean = self.fixed_mean + self.ak
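# A minimal numpy sketch of the ak update above (hypothetical shapes, dense
# neighbour matrix for clarity). It mirrors Fabber's CalculateAk: with
# posterior means w and variances sigma, gk = 1/(0.5*trace + 0.5*term2 + 0.1)
# and hk = W/2 + 1, so ak = gk * hk:
#
#     import numpy as np
#
#     W = 4
#     nn = np.array([[0, 1, 0, 0],
#                    [1, 0, 1, 0],
#                    [0, 1, 0, 1],
#                    [0, 0, 1, 0]], dtype=float)  # chain of 4 vertices
#     w = np.random.randn(W)         # posterior means for this parameter
#     sigma = np.full(W, 0.1)        # posterior variances
#     num_nn = nn.sum(axis=1)
#     trace_term = np.sum(sigma * num_nn)
#     swk = w * num_nn - nn @ w      # D w, where D is the MRF Laplacian
#     term2 = np.sum(swk * w)        # w^T D w
#     ak = (0.5 * W + 1.0) / (0.5 * trace_term + 0.5 * term2 + 0.1)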
class MRFSpatialPrior(Prior):
    """
    Prior which performs adaptive spatial regularization based on the
    contents of neighbouring vertices using the Markov Random Field method

    This uses the same formalism as the Fabber 'M' type spatial prior but treats
    the ak as a parameter of the optimization.
    """

    def __init__(self, nvertices, mean, var, idx=None, post=None, nn=None, n2=None, **kwargs):
        Prior.__init__(self)
        self.name = kwargs.get("name", "MRFSpatialPrior")
        self.nvertices = nvertices
        self.mean = tf.fill([nvertices], mean, name="%s_mean" % self.name)
        self.var = tf.fill([nvertices], var, name="%s_var" % self.name)
        self.std = tf.sqrt(self.var, name="%s_std" % self.name)

        # nn is a sparse tensor of shape [W, W]. If nn[A, B] = 1 then A is
        # a nearest neighbour of B
        self.nn = nn

        # Set up spatial smoothing parameter calculation from posterior and neighbour lists
        # We infer the log of ak.
        self.logak = tf.Variable(-5.0, name="log_ak", dtype=tf.float32)
        self.ak = self.log_tf(tf.exp(self.logak, name="ak"))
    def mean_log_pdf(self, samples):
        r"""
        Mean log PDF for the MRF spatial prior.

        This is calculating:

        :math:`\log P = \frac{1}{2} \log \phi - \frac{\phi}{2}\underline{x^T} D \underline{x}`
        """
        samples = tf.reshape(samples, (self.nvertices, -1))  # [W, N]
        self.num_nn = self.log_tf(tf.sparse_reduce_sum(self.nn, axis=1), name="num_nn")  # [W]
        dx_diag = self.log_tf(tf.reshape(self.num_nn, (self.nvertices, 1)) * samples, name="dx_diag")  # [W, N]
        dx_offdiag = self.log_tf(tf.sparse_tensor_dense_matmul(self.nn, samples), name="dx_offdiag")  # [W, N]
        self.dx = self.log_tf(dx_diag - dx_offdiag, name="dx")  # [W, N]
        self.xdx = self.log_tf(samples * self.dx, name="xdx")  # [W, N]
        term1 = tf.identity(0.5 * self.logak, name="term1")
        term2 = tf.identity(-0.5 * self.ak * self.xdx, name="term2")
        log_pdf = term1 + term2  # [W, N]
        mean_log_pdf = tf.reshape(tf.reduce_mean(log_pdf, axis=-1), [self.nvertices])  # [W]

        # Gamma prior if we care
        #q1, q2 = 1, 100
        #mean_log_pdf += (q1-1) * self.logak - self.ak / q2

        return mean_log_pdf
    def __str__(self):
        return "MRF spatial prior"
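# The matrix D in the docstring is the graph Laplacian of the neighbour graph:
# D = diag(num_nn) - nn. The ``dx_diag - dx_offdiag`` computation above applies
# D to the samples without materializing it. A small numpy check on a
# hypothetical 3-vertex chain:
#
#     import numpy as np
#
#     nn = np.array([[0, 1, 0],
#                    [1, 0, 1],
#                    [0, 1, 0]], dtype=float)
#     x = np.random.randn(3)
#     D = np.diag(nn.sum(axis=1)) - nn
#     dx = nn.sum(axis=1) * x - nn @ x  # dx_diag - dx_offdiag
#     assert np.allclose(dx, D @ x)
#     # x^T D x is the sum of squared differences over edges
#     assert np.isclose(x @ D @ x, 0.5 * np.sum(nn * (x[:, None] - x[None, :]) ** 2))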
class ARDPrior(NormalPrior):
    """
    Automatic Relevance Determination prior
    """

    def __init__(self, nvertices, mean, var, **kwargs):
        # Set the name before calling the base constructor so the mean/var
        # tensors created there are named after this class
        kwargs.setdefault("name", "ARDPrior")
        NormalPrior.__init__(self, nvertices, mean, var, **kwargs)
        self.fixed_var = self.var

        # Set up inferred precision parameter phi
        self.logphi = tf.Variable(tf.log(1/self.fixed_var), name="log_phi", dtype=tf.float32)
        self.phi = self.log_tf(tf.exp(self.logphi, name="phi"))
        self.var = 1/self.phi
        self.std = tf.sqrt(self.var, name="%s_std" % self.name)

    def __str__(self):
        return "ARD prior"
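# Sketch of the ARD mechanism (numpy, hypothetical values): the prior variance
# is reparameterized as 1/exp(logphi) so that unconstrained optimization of
# logphi keeps the precision positive. Initialization matches the fixed prior:
#
#     import numpy as np
#
#     fixed_var = 100.0                  # weakly informative starting prior
#     logphi = np.log(1.0 / fixed_var)   # initial log-precision
#     var = 1.0 / np.exp(logphi)         # == fixed_var initially
#     # During inference logphi is free; if it grows large, var -> 0 and the
#     # parameter is effectively switched off (the ARD "relevance" effect)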
class MRF2SpatialPrior(Prior):
    """
    Prior which performs adaptive spatial regularization based on the
    contents of neighbouring vertices using the Markov Random Field method

    This uses the same formalism as the Fabber 'M' type spatial prior but treats
    the ak as a parameter of the optimization. It differs from MRFSpatialPrior by
    using the pairwise formulation of the PDF rather than the matrix formulation
    (the two are equivalent but currently we keep both around for checking that
    they really are!)

    FIXME currently this does not work unless sample size=1
    """

    def __init__(self, nvertices, mean, var, idx=None, post=None, nn=None, n2=None, **kwargs):
        Prior.__init__(self)
        self.name = kwargs.get("name", "MRF2SpatialPrior")
        self.nvertices = nvertices
        self.mean = tf.fill([nvertices], mean, name="%s_mean" % self.name)
        self.var = tf.fill([nvertices], var, name="%s_var" % self.name)
        self.std = tf.sqrt(self.var, name="%s_std" % self.name)

        # nn is a sparse tensor of shape [W, W]. If nn[A, B] = 1 then A is
        # a nearest neighbour of B
        self.nn = nn

        # We need the number of samples to implement the log PDF function
        self.sample_size = kwargs.get("sample_size", 5)

        # Set up spatial smoothing parameter calculation from posterior and neighbour lists
        self.logak = tf.Variable(-5.0, name="log_ak", dtype=tf.float32)
        self.ak = self.log_tf(tf.exp(self.logak, name="ak"))
    def mean_log_pdf(self, samples):
        samples = tf.reshape(samples, (self.nvertices, -1))  # [W, N]
        self.num_nn = self.log_tf(tf.sparse_reduce_sum(self.nn, axis=1), name="num_nn")  # [W]

        expanded_nn = tf.sparse_concat(2, [tf.sparse.reshape(self.nn, (self.nvertices, self.nvertices, 1))] * self.sample_size)
        xj = expanded_nn * tf.reshape(samples, (self.nvertices, 1, -1))
        #xi = tf.reshape(tf.sparse.to_dense(tf.sparse.reorder(self.nn)), (self.nvertices, self.nvertices, 1)) * tf.reshape(samples, (1, self.nvertices, -1))
        xi = expanded_nn * tf.reshape(samples, (1, self.nvertices, -1))
        #xi = tf.sparse.transpose(xj, perm=(1, 0, 2))
        neg_xi = tf.SparseTensor(xi.indices, -xi.values, dense_shape=xi.dense_shape)
        dx2 = tf.square(tf.sparse.add(xj, neg_xi), name="dx2")
        sdx = tf.sparse.reduce_sum(dx2, axis=0)  # [W, N]
        term1 = tf.identity(0.5 * self.logak, name="term1")
        term2 = tf.identity(-self.ak * sdx / 4, name="term2")
        log_pdf = term1 + term2  # [W, N]
        mean_log_pdf = tf.reshape(tf.reduce_mean(log_pdf, axis=-1), [self.nvertices])  # [W]
        return mean_log_pdf
    def __str__(self):
        return "MRF2 spatial prior"
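# Numerical check (numpy, single sample) that MRF2's pairwise sum matches the
# matrix form used by MRFSpatialPrior: summing (x_i - x_j)^2 over ordered
# neighbour pairs counts each edge twice, so sdx/4 summed over vertices equals
# 0.5 * x^T D x, i.e. the same quadratic form up to the factor of ak:
#
#     import numpy as np
#
#     nn = np.array([[0, 1, 1],
#                    [1, 0, 0],
#                    [1, 0, 0]], dtype=float)
#     x = np.random.randn(3)
#     D = np.diag(nn.sum(axis=1)) - nn
#     sdx = np.sum(nn * (x[:, None] - x[None, :]) ** 2, axis=0)  # per-vertex, as above
#     assert np.isclose(np.sum(sdx) / 4, 0.5 * (x @ D @ x))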
class ConstantMRFSpatialPrior(NormalPrior):
    """
    Prior which performs adaptive spatial regularization based on the
    contents of neighbouring vertices using the Markov Random Field method

    This is equivalent to the Fabber 'M' type spatial prior
    """

    def __init__(self, nvertices, mean, var, idx=None, nn=None, n2=None, **kwargs):
        """
        :param mean: Tensor of shape [W] containing the prior mean at each parameter vertex
        :param var: Tensor of shape [W] containing the prior variance at each parameter vertex
        :param nn: Sparse tensor of shape [W, W] containing nearest neighbour lists
        :param n2: Sparse tensor of shape [W, W] containing second nearest neighbour lists
        """
        NormalPrior.__init__(self, nvertices, mean, var, name="MRFSpatialPrior")
        self.idx = idx

        # Save the original vertexwise mean and variance - the actual prior mean/var
        # will be calculated from these and also the spatial variation in neighbour vertices
        self.fixed_mean = self.mean
        self.fixed_var = self.var

        # nn and n2 are sparse tensors of shape [W, W]. If nn[A, B] = 1 then A is
        # a nearest neighbour of B, and similarly for n2 and second nearest neighbours
        self.nn = nn
        self.n2 = n2

    def __str__(self):
        return "Spatial MRF prior (%f, %f) - const" % (self.scalar_mean, self.scalar_var)

    def update_ak(self, post_mean, post_cov):
        # This is the equivalent of CalculateAk in Fabber
        #
        # Some of this could probably be better done using linalg
        # operations but bear in mind this is one parameter only
        self.sigmaK = post_cov[:, self.idx, self.idx]  # [W]
        self.wK = post_mean[:, self.idx]  # [W]
        self.num_nn = np.sum(self.nn, axis=1)  # [W]

        # Sum over vertices of parameter variance multiplied by number of
        # nearest neighbours for each vertex
        trace_term = np.sum(self.sigmaK * self.num_nn)  # [1]

        # Sum of nearest and next-nearest neighbour mean values
        self.sum_means_nn = np.reshape(np.matmul(self.nn, np.reshape(self.wK, (-1, 1))), (-1,))  # [W]
        self.sum_means_n2 = np.reshape(np.matmul(self.n2, np.reshape(self.wK, (-1, 1))), (-1,))  # [W]

        # vertex parameter mean multiplied by number of nearest neighbours
        wknn = self.wK * self.num_nn  # [W]

        swk = wknn - self.sum_means_nn  # [W]
        term2 = np.sum(swk * self.wK)  # [1]

        gk = 1 / (0.5 * trace_term + 0.5 * term2 + 0.1)
        hk = float(self.nvertices) * 0.5 + 1.0
        self.ak = gk * hk
        self.log.info("%s: ak=%f", self.name, self.ak)

    def _setup_mean_var(self, post_mean, post_cov):
        # This is the equivalent of ApplyToMVN in Fabber
        contrib_nn = self.log_tf(8*self.sum_means_nn, name="contrib_nn")  # [W]
        contrib_n2 = self.log_tf(-self.sum_means_n2, name="contrib_n2")  # [W]

        spatial_mean = self.log_tf(contrib_nn / (8*self.num_nn), name="spatial_mean")
        spatial_prec = self.log_tf(self.num_nn * self.ak, name="spatial_prec")

        self.var = self.log_tf(1 / (1/self.fixed_var + spatial_prec), name="%s_var" % self.name)
        #self.var = self.fixed_var
        self.mean = self.log_tf(self.var * spatial_prec * spatial_mean, name="%s_mean" % self.name)
        #self.mean = self.fixed_mean + self.ak
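# Sketch of the intended update cycle (an assumption from the method names;
# the actual driver code lives elsewhere in svb): after each update of the
# posterior, recompute ak and then the spatially-regularized prior mean/var:
#
#     # post_mean: [W, P] array, post_cov: [W, P, P] array from the posterior
#     prior.update_ak(post_mean, post_cov)
#     prior._setup_mean_var(post_mean, post_cov)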
class FactorisedPrior(Prior):
    """
    Prior for a collection of parameters where there is no prior covariance

    In this case the mean log PDF can be summed from the contributions of
    each parameter
    """

    def __init__(self, priors, **kwargs):
        Prior.__init__(self)
        self.priors = priors
        self.name = kwargs.get("name", "FactPrior")
        self.nparams = len(priors)

        means = [prior.mean for prior in self.priors]
        variances = [prior.var for prior in self.priors]
        self.mean = self.log_tf(tf.stack(means, axis=-1, name="%s_mean" % self.name))
        self.var = self.log_tf(tf.stack(variances, axis=-1, name="%s_var" % self.name))
        self.std = tf.sqrt(self.var, name="%s_std" % self.name)
        self.nvertices = priors[0].nvertices

        # Define a diagonal covariance matrix for convenience
        self.cov = tf.matrix_diag(self.var, name='%s_cov' % self.name)
    def mean_log_pdf(self, samples):
        nvertices = tf.shape(samples)[0]

        mean_log_pdf = tf.zeros([nvertices], dtype=tf.float32)
        for idx, prior in enumerate(self.priors):
            param_samples = tf.slice(samples, [0, idx, 0], [-1, 1, -1])
            param_logpdf = prior.mean_log_pdf(param_samples)
            mean_log_pdf = tf.add(mean_log_pdf, param_logpdf)
        return mean_log_pdf
    def log_det_cov(self):
        """
        Determinant of diagonal matrix is product of diagonal entries
        """
        return tf.reduce_sum(tf.log(self.var), axis=1, name='%s_log_det_cov' % self.name)
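# Quick numpy check of the identity used above: for a diagonal covariance,
# log(det(C)) equals the sum of the logs of the diagonal entries:
#
#     import numpy as np
#
#     var = np.array([0.5, 2.0, 4.0])
#     C = np.diag(var)
#     assert np.isclose(np.log(np.linalg.det(C)), np.sum(np.log(var)))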