# Source code for svb.parameter

"""
SVB - Model parameters

This module defines a set of classes of model parameters.

The factory methods which create priors/posteriors can
make use of the instance class to create the appropriate
type of vertexwise prior/posterior
"""
try:
    import tensorflow.compat.v1 as tf
except ImportError:
    import tensorflow as tf
   
from .utils import LogBase
from . import dist

def get_parameter(name, **kwargs):
    """
    Factory method to create an instance of a parameter

    :param name: Parameter name. Also used as the key when looking up
                 per-parameter settings in ``param_overrides``

    Keyword arguments (optional):

     - ``param_overrides`` Dictionary keyed by parameter name. The entry
       for ``name`` (if any) is a dictionary of keyword arguments which
       is merged over the other keyword arguments before they are used
     - ``desc`` Parameter description (default ``"No description given"``)
     - ``prior_type`` Prior type code (default ``"N"``)
     - ``post_type`` Posterior type code (default ``"vertexwise"``)
     - ``post_init`` Posterior initialisation (default ``None``)

    After the overrides are merged, the complete set of keyword arguments
    is also forwarded to ``dist.get_dist`` - once with ``prefix="prior"``
    and once with ``prefix="post"`` - to build the two distributions.

    :return: ``Parameter`` instance
    """
    # param_overrides is removed from kwargs so it is not forwarded to
    # dist.get_dist; only the entry for this parameter is merged back in
    overrides = kwargs.pop("param_overrides", {})
    kwargs.update(overrides.get(name, {}))

    # Prior is built before posterior, both from the same merged options
    prior = dist.get_dist(prefix="prior", **kwargs)
    post = dist.get_dist(prefix="post", **kwargs)

    return Parameter(
        name,
        desc=kwargs.get("desc", "No description given"),
        prior=prior,
        prior_type=kwargs.get("prior_type", "N"),
        post=post,
        post_init=kwargs.get("post_init"),
        post_type=kwargs.get("post_type", "vertexwise"),
    )
class Parameter(LogBase):
    """
    A standard model parameter
    """

    def __init__(self, name, **kwargs):
        """
        Constructor

        :param name: Parameter name

        Keyword arguments (optional):

         - ``prior`` Dist instance giving the parameter's prior
           distribution (default ``None``)
         - ``post`` Dist instance giving the posterior distribution.
           If not given, the prior distribution is used
         - ``desc`` Parameter description
           (default ``"No description given"``)
         - ``prior_type`` Prior type code (default ``"N"``)
         - ``post_type`` Posterior type code (default ``"vertexwise"``)
         - ``post_init`` Posterior initialisation (default ``None``)
         - ``param_overrides`` Dictionary keyed by parameter name. Value
           should be dictionary of keyword arguments which will override
           those defined as existing keyword arguments

        Any other keyword arguments (e.g. ``mean_init``, ``log_var_init``)
        are accepted but are not read by this constructor.
        """
        LogBase.__init__(self)

        # Merge the overrides registered for this parameter name (if any)
        # over the keyword arguments that were passed directly
        overrides = kwargs.pop("param_overrides", {})
        kwargs.update(overrides.get(name, {}))

        prior = kwargs.get("prior")

        self.name = name
        self.desc = kwargs.get("desc", "No description given")
        self.prior_dist = prior
        self.prior_type = kwargs.get("prior_type", "N")
        # Posterior falls back to the prior distribution when not supplied
        self.post_dist = kwargs.get("post", prior)
        self.post_type = kwargs.get("post_type", "vertexwise")
        self.post_init = kwargs.get("post_init")

    def __str__(self):
        """
        :return: Human readable identifier including the parameter name
        """
        return "Parameter: %s" % self.name