Source code for mbproj2d.par

# Copyright (C) 2020 Jeremy Sanders <jeremy@jeremysanders.net>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import sys
import math
import pickle
import fnmatch
import re
import itertools

import numpy as N
import scipy.stats

from . import utils

[docs] class PriorBase: """Base class for all Priors.""" def calculate(self, val): return 0 def __repr__(self): return '<PriorBase: None>'
[docs] def paramFromUnit(self, unit): """Compute a parameter value to an input 0...1.""" return None
def copy(self): return PriorBase()
[docs] def bounds(self): """Return upper and lower bounds.""" return -N.inf, N.inf
[docs] def makeValidValue(self, val): """Transform value to pass to model. This can, for example, ensure that a model parameter is valid, even if the bounds are soft """ return val
[docs] class PriorFlat(PriorBase): """Flat prior. :param minval: minimum allowed value :param minval: maximum allowed value """ def __init__(self, minval, maxval): PriorBase.__init__(self) self.minval = minval self.maxval = maxval def calculate(self, val): if self.minval <= val <= self.maxval: return 0 else: return -N.inf def __repr__(self): return '<PriorFlat: minval=%s, maxval=%s>' % ( repr(self.minval), repr(self.maxval))
[docs] def paramFromUnit(self, unit): return (self.maxval-self.minval)*unit + self.minval
def copy(self): return PriorFlat(self.minval, self.maxval)
[docs] def bounds(self): return self.minval, self.maxval
[docs] class PriorFlatSoft(PriorBase): """Flat prior, where the likelihood is increase sharply increased at the bounds. The prior is exponential in width beyond the edges :param minval: soft minimum value :param maxval: soft maximum value :param width: width of transition beyond minval/maxval """ def __init__(self, minval, maxval, width=0.01): PriorBase.__init__(self) self.minval = minval self.maxval = maxval self.width = width def calculate(self, val): if val < self.minval: scale = (self.minval-val)/self.width if abs(scale > 40): return -N.inf return -(N.exp(scale)-1) elif val > self.maxval: scale = (val-self.maxval)/self.width if abs(scale > 40): return -N.inf return -(N.exp(scale)-1) else: return 0 def __repr__(self): return '<PriorFlatSoft: minval=%s, maxval=%s, width=%s>' % ( repr(self.minval), repr(self.maxval), repr(self.width) ) def copy(self): return PriorFlatSoft(self.minval, self.maxval, width=self.width)
[docs] def makeValidValue(self, val): return N.clip(val, self.minval, self.maxval)
[docs] class PriorGaussian(PriorBase): """Gaussian prior :param mu: Gaussian centre :param sigma: Gaussian width """ def __init__(self, mu, sigma): PriorBase.__init__(self) self.mu = mu self.sigma = sigma def calculate(self, val): if self.sigma <= 0: return -N.inf else: return ( -0.5*math.log(2*math.pi) -math.log(self.sigma) -0.5*((val - self.mu) / self.sigma)**2 ) def __repr__(self): return '<PriorGaussian: mu=%s, sigma=%s>' % ( self.mu, self.sigma)
[docs] def paramFromUnit(self, unit): return scipy.stats.norm.ppf(unit, self.mu, self.sigma)
def copy(self): return PriorGaussian(self.mu, self.sigma)
[docs] class PriorBoundedGaussian(PriorBase): """Gaussian prior Bounded :param mu: Gaussian centre :param sigma: Gaussian width :param minval: minimum allowed value :param maxval: maximum allowed value """ def __init__(self, mu, sigma, minval=None, maxval=None): PriorBase.__init__(self) self.mu = mu self.sigma = sigma if minval is None: minval = -N.inf if maxval is None: maxval = +N.inf self.minval = minval self.maxval = maxval def calculate(self, val): if self.sigma > 0 and ( self.minval <= val <= self.maxval ): return ( -0.5*math.log(2*math.pi) -math.log(self.sigma) -0.5*((val - self.mu) / self.sigma)**2 ) else: return -N.inf def __repr__(self): return '<PriorBoundedGaussian: mu=%s, sigma=%s, minval=%s, maxval=%s>' % ( self.mu, self.sigma, self.minval, self.maxval)
[docs] def paramFromUnit(self, unit): a = (self.minval - self.mu) / self.sigma b = (self.maxval - self.mu) / self.sigma return scipy.stats.truncnorm.ppf( unit, a, b, loc=self.mu, scale=self.sigma)
def copy(self): return PriorBoundedGaussian( self.mu, self.sigma, self.minval, self.maxval)
[docs] def bounds(self): return self.minval, self.maxval
[docs] class PriorBoundedGaussianSoft(PriorBase): """Gaussian prior Bounded with soft limits :param mu: Gaussian centre :param sigma: Gaussian width :param minval: minimum allowed value :param maxval: maximum allowed value """ def __init__(self, mu, sigma, minval=None, maxval=None, width=0.01): PriorBase.__init__(self) self.mu = mu self.sigma = sigma if minval is None: minval = -N.inf if maxval is None: maxval = +N.inf self.minval = minval self.maxval = maxval self.width = width def calculate(self, val): if self.sigma < 0: return -N.inf gfunc = ( -0.5*math.log(2*math.pi) -math.log(self.sigma) -0.5*((val - self.mu) / self.sigma)**2 ) # add on sharp edges if val < self.minval: scale = (self.minval-val)/self.width gfunc -= N.exp(scale)-1 elif val > self.maxval: scale = (val-self.maxval)/self.width gfunc -= N.exp(scale)-1 return gfunc def __repr__(self): return '<PriorBoundedGaussianSoft: mu=%s, sigma=%s, minval=%s, maxval=%s, width=%s>' % ( self.mu, self.sigma, self.minval, self.maxval, self.width)
[docs] def paramFromUnit(self, unit): a = (self.minval - self.mu) / self.sigma b = (self.maxval - self.mu) / self.sigma return scipy.stats.truncnorm.ppf( unit, a, b, loc=self.mu, scale=self.sigma)
def copy(self): return PriorBoundedGaussianSoft( self.mu, self.sigma, self.minval, self.maxval, width=self.width)
[docs] def makeValidValue(self, val): return N.clip(val, self.minval, self.maxval)
[docs] class Par: """Parameter for model. :param float val: parameter value :param prior: prior object or None for flat prior :param frozen: whether to leave parameter frozen :param xform: function to transform value for model or 'exp' for an exp(x) scaling :param linked: another Par object to link this parameter to another :param float minval: minimum value for default flat prior :param float maxval: maximum value for default flat prior :param soft: use a soft flat prior instead of a sharp one """ def __init__( self, val, prior=None, frozen=False, xform=None, linked=None, minval=-N.inf, maxval=N.inf, soft=False): self.val = val self.frozen = frozen if prior is None: if soft: self.prior = PriorFlatSoft(minval, maxval) else: self.prior = PriorFlat(minval, maxval) else: self.prior = prior if xform is None: self.xform = None elif xform == 'exp': self.xform = lambda x: math.exp(x) else: self.xform = xform self.linked = linked @property def v(self): """Value for using in model, after transformation or linking, if any.""" if self.linked is None: val = self.val else: val = self.linked.val val = self.prior.makeValidValue(val) if self.xform is None: return val else: return self.xform(val)
[docs] def isFree(self): """Is the parameter free?""" return self.linked is None and not self.frozen
[docs] def calcPrior(self): """Calculate prior.""" if self.linked is not None: return 0 else: return self.prior.calculate(self.val)
def __repr__(self): if self.linked is not None: p = [ 'linked=%s' % self.linked, ] else: p = [ 'val=%.5g' % self.val, 'frozen=%s' % self.frozen, ] p.append('prior=%s' % repr(self.prior)) if self.xform is not None: p.append('xform=%s' % self.xform) return '<Par: %s>' % (', '.join(p)) def copy(self): # linking is not deep copied: this is fixed by Pars below return Par( self.val, prior=self.prior.copy(), frozen=self.frozen, linked=self.linked, xform=self.xform)
[docs] class Pars(dict): """Parameters for a model. This is based around a dictionary class. Each parameter has a name. """
[docs] def numFree(self): """Return number of free parameters""" return len(self.freeKeys())
[docs] def freeKeys(self): """Return sorted list of keys of parameters which are free""" return [key for key in sorted(self) if self[key].isFree()]
[docs] def freeVals(self): """Return list of values for parameters which are free in sorted key order.""" return [par.val for key, par in sorted(self.items()) if par.isFree()]
[docs] def setFree(self, vals): """Given a list of values, set those which are free. Note: number of free parameters should be number of vals """ i = 0 for key in sorted(self): par = self[key] if par.isFree(): par.val = vals[i] i += 1
[docs] def calcPrior(self): """Return total prior of parameters.""" return sum((par.calcPrior() for par in self.values()))
def __repr__(self): # sorted repr (to match above) out = [] for key in sorted(self): out.append('%s: %s' % (repr(key), repr(self[key]))) return '{%s}' % (', '.join(out))
[docs] def write(self, file=sys.stdout): """Print out parameters.""" vtok = {v: k for k, v in self.items()} for k, v in sorted(self.items()): out = [ '%16s:' % k, ] if v.linked: out += [ '%12s' % vtok[v.linked], 'linked' ] else: out += [ '%12g' % v.val, 'frozen' if v.frozen else 'thawed', ] out.append('%45s' % repr(v.prior)) if v.xform: out.append('xform=%s' % repr(v.xform)) utils.uprint(' '.join(out), file=file)
[docs] def copy(self): """Return a deep copy of self.""" newpars = Pars() for k, v in self.items(): newpars[k] = v.copy() # fixup links to point to new parameters. vtok = {v: k for k, v in self.items()} for k, v in newpars.items(): if v.linked is not None: v.linked = newpars[vtok[v.linked]] return newpars
[docs] def save(self, filename): """Saves the parameters as a Python pickle. *Note*: upgrading the source code of mbproj2d or your prior may prevent the saved file from being loadable again. Take care before relying on this for long term storage. :param filename: output filename """ with open(filename, 'wb') as f: pickle.dump(self, f)
[docs] def load(self, filename, skip=False): """Load parameters from file. :param filename: filename to load from :param skip: if set, then we continue if file not found """ try: with open(filename, 'rb') as f: pars = pickle.load(f) except OSError as e: if skip: return else: raise e if len(self) != len(pars): raise RuntimeError("Number of parameters loaded does not match number of parameters") self.update(pars)
[docs] def match(self, pattern, use_re=False): """Returns a dictionary of parameters whose names match a pattern. :param pattern: glob-style parameter match, e.g. "ne_*" or "abc_???_alpha" (default), a regular expression string (if use_re) :param use_re: if set, treat pattern as a regular expression Returns {'name': par, ...} """ out = {} for name in self: if ( (use_re and re.match(pattern, name) is not None) or (not use_re and fnmatch.fnmatchcase(name, pattern)) ): out[name] = self[name] return out
[docs] def matchFreeze(self, pattern, use_re=False): """Freeze parameters which match the name given. :param pattern: glob-style parameter match, e.g. "ne_*" or "abc_???_alpha" :param use_re: if set, treat pattern as a regular expression """ for par in self.match(pattern, use_re=use_re).values(): par.frozen = True
[docs] def matchThaw(self, pattern, use_re=False): """Thaw parameters which match the name given. :param pattern: glob-style parameter match, e.g. "ne_*" or "abc_???_alpha". :param use_re: if set, treat pattern as a regular expression """ for par in self.match(pattern, use_re=use_re).values(): par.frozen = False
[docs] def matchSet(self, pattern, val, use_re=False): """Set values for parameters which match the name given. :param pattern: glob-style parameter match, e.g. "ne_*" or "abc_???_alpha". :param val: constant (to set to same value) or iterable (to set to sequence) """ try: valiter = iter(val) except TypeError: valiter = itertools.repeat(val) for par in self.match(pattern, use_re=use_re).values(): par.val = next(valiter)
[docs] def bounds(self): """Return lower,upper bounds for free parameters.""" lower = [] upper = [] for par in self.freeKeys(): l, u = self[par].prior.bounds() lower.append(l) upper.append(u) return lower, upper