Source code for pylfi.journal.journal

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import gridspec

plt.rc('text', usetex=True)


class Journal:
    r"""Journal with results and information created by an inference run.

    An inference scheme fills the journal through ``_write_to_journal``.
    Afterwards the posterior samples are available as a
    :class:`pandas.DataFrame` through the ``df`` property, together with
    helpers for summary statistics (HDI, point estimates, RMSPE) and plots.

    Accessing ``df`` (directly or through a method that uses it) before the
    journal has been written raises ``RuntimeError``.
    """

    def __init__(self):
        # list of parameter names (the 'name' kw from Prior object)
        self._param_names = []
        # list of parameter LaTeX names (the 'tex' kw from Prior object)
        self._param_names_tex = []
        # for tallying the number of inferred parameters
        self._n_params = 0
        # sampler configuration and info
        self._info_df = {}
        # posterior samples dicts for data structures
        self._idata = {}
        self._idata_plot = {}
        self._df = {}
        self._df_plot = {}

        # flags
        # Becomes True if df with tex names are constructed
        self._df_plot_exist = False
        # Becomes True if idata with tex names are constructed
        self._idata_plot_exist = False
        # bool used to limit access if journal has not been written to
        self._journal_written = False

    def _write_to_journal(
        self,
        inference_scheme,
        observation,
        # simulator,
        # stat_calc,
        priors,
        n_samples,
        n_chains,
        n_sims,
        samples,
        accept_ratio,
        epsilon,
        quantile
    ):
        """Write the results of an inference run to the journal.

        The journal state is rebuilt from scratch on every call, so writing
        to the same journal twice replaces the previous content instead of
        appending duplicate parameter names.

        Parameters
        ----------
        inference_scheme
            Stored under the ``"Inference scheme"`` key of the info dict.
        observation
            Stored as-is on the journal.
        priors
            Sequence of prior objects exposing ``name`` and ``tex``
            attributes, one per column of ``samples``.
        n_samples, n_chains, n_sims
            Accepted for interface compatibility; currently not stored.
        samples
            2-D array indexed as ``samples[:, i]`` for the i-th prior.
        accept_ratio, epsilon, quantile
            Stored in the info dict.
        """
        self._observation = observation
        #self._simulator = simulator
        #self._stat_calc = stat_calc
        self._priors = priors

        prior_seq = list(priors)

        # Rebuild the name lists instead of appending, so that a repeated
        # write does not leave stale or duplicated entries behind.
        self._param_names = [prior.name for prior in prior_seq]
        self._n_params = len(self._param_names)
        self._idata = dict.fromkeys(self._param_names)
        self._df = dict.fromkeys(self._param_names)

        # The plotting frame is looked up by column *position*, so it needs
        # exactly one label per parameter. Priors without a LaTeX name fall
        # back to their plain name to keep both frames aligned.
        has_tex = any(prior.tex is not None for prior in prior_seq)
        if has_tex:
            self._param_names_tex = [
                prior.name if prior.tex is None else prior.tex
                for prior in prior_seq
            ]
        else:
            self._param_names_tex = []
        self._idata_plot = dict.fromkeys(self._param_names_tex)
        self._df_plot = dict.fromkeys(self._param_names_tex)
        self._df_plot_exist = has_tex
        self._idata_plot_exist = has_tex

        # write to data structures
        self._write_config(inference_scheme, quantile, epsilon, accept_ratio)
        self._write_df(samples)
        #self._write_idata(samples, n_samples, n_chains)

        # Written to journal
        self._journal_written = True

    def _write_config(self, inference_scheme, quantile, epsilon, accept_ratio):
        """Store the sampler configuration in the info dict."""
        self._info_df["Inference scheme"] = inference_scheme
        #self._info_df["Simulator model"] = self._simulator.__name__
        self._info_df["quantile"] = quantile
        self._info_df["epsilon"] = epsilon
        self._info_df["accept_ratio"] = accept_ratio

    def _write_df(self, samples):
        """Build the sample data frames from the columns of ``samples``.

        ``self._df`` gets one column per parameter name. When LaTeX labels
        exist, ``self._df_plot`` gets the same columns under those labels.
        """
        self._df = pd.DataFrame(
            {name: samples[:, i] for i, name in enumerate(self._param_names)}
        )
        if self._df_plot_exist:
            self._df_plot = pd.DataFrame(
                {
                    label: samples[:, i]
                    for i, label in enumerate(self._param_names_tex)
                }
            )

    def _check_journal_status(self):
        """Check if journal has been initiated by an inference scheme.

        Raises
        ------
        RuntimeError
            If journal has not been initiated by an inference scheme.
        """
        if not self._journal_written:
            msg = ("Journal unavailable; run an inference scheme first")
            raise RuntimeError(msg)

    def _param_index(self, theta_name):
        """Return the column position of ``theta_name`` in ``self.df``.

        Raises
        ------
        IndexError
            If no column of ``self.df`` is called ``theta_name``.
        """
        matches = np.flatnonzero(self.df.columns.to_numpy() == theta_name)
        if matches.size == 0:
            raise IndexError(f"Unknown parameter name: {theta_name!r}")
        return int(matches[0])

    def _plot_data(self, *theta_names):
        """Return the frame to plot from, followed by its column labels.

        When a frame with LaTeX labels exists, each name in ``theta_names``
        is translated to the label at the same column position; otherwise
        the plain sample frame and the names themselves are returned.
        """
        if not self._df_plot_exist:
            return (self.df, *theta_names)
        columns = self._df_plot.columns
        labels = [columns[self._param_index(name)] for name in theta_names]
        return (self._df_plot, *labels)

    def thetas_pred(self, size=50):
        """Draw rows, with replacement, from the stored posterior samples.

        Parameters
        ----------
        size : int, optional
            Number of rows to draw. Default is 50.

        Returns
        -------
        numpy.ndarray
            Array with ``size`` rows and one column per parameter.
        """
        df = self.df
        idxs = np.random.randint(0, len(df.index), size)
        return df.iloc[idxs].to_numpy()

    def compute_rmspe(self, theta_true, theta_pred):
        """Root mean square percentage error (RMSPE).

        Computes ``100 * sqrt(mean(((theta_true - theta_pred) / theta_true)**2))``.
        """
        relative_error = (theta_true - theta_pred) / theta_true
        return np.sqrt(np.mean(np.square(relative_error))) * 100

    def compute_hdi(self, theta, hdi_prob):
        """Highest (posterior) density interval.

        Returns the narrowest interval spanning ``floor(hdi_prob * n)``
        steps in the sorted samples, where ``n`` is the number of samples.

        Parameters
        ----------
        theta : array_like
            Samples; flattened before use.
        hdi_prob : float
            Probability mass of the interval, in ``[0, 1]``.

        Returns
        -------
        tuple
            ``(hdi_min, hdi_max)``.

        Raises
        ------
        ValueError
            If ``theta`` is empty or ``hdi_prob`` is outside ``[0, 1]``.
        """
        theta = np.sort(np.asarray(theta).ravel())
        n = theta.size
        if n == 0:
            raise ValueError("Cannot compute an HDI from an empty sample")
        if not 0 <= hdi_prob <= 1:
            raise ValueError(f"hdi_prob must be in [0, 1], got {hdi_prob}")

        # Clamp to n - 1 so that hdi_prob == 1 yields the full sample range
        # instead of an empty set of candidate intervals.
        interval_idx_inc = min(int(np.floor(hdi_prob * n)), n - 1)
        n_intervals = n - interval_idx_inc
        # np.float64 instead of the np.float_ alias removed in NumPy 2.0
        interval_width = np.subtract(
            theta[interval_idx_inc:],
            theta[:n_intervals],
            dtype=np.float64
        )

        min_idx = np.argmin(interval_width)
        hdi_min = theta[min_idx]
        hdi_max = theta[min_idx + interval_idx_inc]

        return (hdi_min, hdi_max)

    def compute_point_est(self, theta, density, point_estimate):
        """Compute point estimate.

        Parameters
        ----------
        theta : numpy.ndarray
            Values the estimate is computed from.
        density : array_like
            Density evaluated at ``theta``; only used for ``"map"``.
        point_estimate : {"mean", "median", "map"}
            Which estimate to compute. ``"map"`` picks the element of
            ``theta`` at which ``density`` is largest.

        Returns
        -------
        tuple
            ``(estimate, handle)`` where ``handle`` is the label used in
            plot legends (``"mean"``, ``"median"`` or ``"MAP"``).

        Raises
        ------
        ValueError
            If ``point_estimate`` is not one of the supported options.
        """
        if point_estimate == "mean":
            return theta.mean(), "mean"
        if point_estimate == "median":
            return np.median(theta), "median"
        if point_estimate == "map":
            return theta[np.argmax(density)], "MAP"
        raise ValueError(f"Unrecognized point estimate: {point_estimate}")

    def plot_prior(
        self,
        theta_name,
        x,
        color='C0',
        facecolor='lightblue',
        alpha=0.5,
        ax=None,
        **kwargs
    ):
        """Plot the prior of the parameter called ``theta_name``.

        Delegates to the ``plot_prior`` method of the matching prior object.
        ``color``, ``facecolor`` and ``alpha`` are accepted for interface
        compatibility but are currently not forwarded.
        """
        prior = self._priors[self._param_index(theta_name)]
        prior.plot_prior(x, ax=ax, **kwargs)

    def plot_posterior(
        self,
        theta_name,
        hdi_prob=0.95,
        point_estimate="map",
        theta_true=None,
        ax=None
    ):
        """Plot the KDE of a marginal posterior with point estimate and HDI.

        Parameters
        ----------
        theta_name : str
            Name of the parameter to plot.
        hdi_prob : float, optional
            Probability mass of the shaded HDI. Default is 0.95.
        point_estimate : {"map", "mean", "median"}, optional
            Point estimate marked on the curve. Default is ``"map"``.
        theta_true : float, optional
            If given, a vertical line is drawn at this value and the RMSPE
            of the samples relative to it is shown in the legend.
        ax : matplotlib axes, optional
            Axes to draw on; defaults to ``plt.gca()``.
        """
        if ax is None:
            ax = plt.gca()

        df, theta_name = self._plot_data(theta_name)
        xdata = df[theta_name].to_numpy()

        # Remember how many lines the axes already hold so the KDE curve
        # can be picked out even when ``ax`` is not empty.
        n_lines_before = len(ax.lines)

        if theta_true is not None:
            theta_true_ary = np.full(xdata.size, theta_true, dtype=float)
            rmspe = self.compute_rmspe(theta_true_ary, xdata)
            sns.kdeplot(
                data=df,
                x=theta_name,
                color='C0',
                label=rf"Posterior RMSPE: {rmspe:.2f}\%",
                ax=ax
            )
            ax.axvline(
                theta_true,
                ymax=0.3,
                color='C3',
                label=fr"$\theta_\mathrm{{true}}: {theta_true}$"
            )
        else:
            sns.kdeplot(
                data=df,
                x=theta_name,
                color='C0',
                label="Posterior",
                ax=ax
            )

        kdeline = ax.lines[n_lines_before]
        xs = kdeline.get_xdata()
        ys = kdeline.get_ydata()

        p_est, p_est_handle = self.compute_point_est(xs, ys, point_estimate)
        hdi_min, hdi_max = self.compute_hdi(xdata, hdi_prob=hdi_prob)

        ax.vlines(
            p_est,
            0,
            np.interp(p_est, xs, ys),
            color='b',
            ls=':',
            label=fr"$\hat{{\theta}}_\mathrm{{{p_est_handle}}}: {p_est:.3f}$"
        )
        ax.fill_between(
            xs,
            0,
            ys,
            facecolor='lightblue',
            alpha=0.3
        )
        ax.fill_between(
            xs,
            0,
            ys,
            where=(hdi_min <= xs) & (xs <= hdi_max),
            interpolate=True,
            facecolor='steelblue',
            alpha=0.3,
            label=rf"{hdi_prob*100}\% HDI: [{hdi_min:.3f}, {hdi_max:.3f}]"
        )

        handles, labels = ax.get_legend_handles_labels()
        ax.set(yticks=[])
        ax.legend(
            handles,
            labels,
            loc='center left',
            bbox_to_anchor=(1.04, 0.5),
            fancybox=True,
            borderaxespad=0.1,
            ncol=1,
            frameon=False
        )

    def plot_joint(
        self,
        theta1_name,
        theta2_name,
        theta1_true=None,
        theta2_true=None,
        levels=6,
        alpha=0.3,
        height=4,
        **kwargs
    ):
        """Plot the joint KDE of two parameters with marginal densities.

        When both ``theta1_true`` and ``theta2_true`` are given, the true
        point is marked on the joint axes and on both marginal axes.
        Extra keyword arguments are forwarded to ``sns.jointplot``.

        Returns
        -------
        The grid object returned by ``sns.jointplot``.
        """
        df, theta1_name, theta2_name = self._plot_data(theta1_name, theta2_name)

        g = sns.jointplot(
            data=df,
            x=theta1_name,
            y=theta2_name,
            kind="kde",
            fill=True,
            height=height,
            **kwargs
        )
        g.plot_joint(
            sns.kdeplot,
            color="k",
            levels=levels,
            alpha=alpha
        )

        if theta1_true is not None and theta2_true is not None:
            g.ax_joint.plot([theta1_true], [theta2_true], 'ro')
            g.ax_joint.axvline(theta1_true, color='r', ls=":")
            g.ax_joint.axhline(theta2_true, color='r', ls=":")
            g.ax_marg_x.axvline(theta1_true, color='r')
            g.ax_marg_y.axhline(theta2_true, color='r')

        return g

    def heatmap(
        self,
        measure,
        cmap="coolwarm",
        vmin=None,
        vmax=None,
        ax=None,
        **kwargs
    ):
        """Draw a heatmap of the sample covariance or correlation matrix.

        Parameters
        ----------
        measure : {"cov", "corr"}
            Which matrix of the sample frame to show.
        cmap, vmin, vmax, ax, **kwargs
            Forwarded to ``sns.heatmap``; ``ax`` defaults to ``plt.gca()``.

        Raises
        ------
        ValueError
            If ``measure`` is neither ``"cov"`` nor ``"corr"``.
        """
        if ax is None:
            ax = plt.gca()

        df = self._df_plot if self._df_plot_exist else self.df

        if measure == 'cov':
            res = df.cov()
        elif measure == 'corr':
            res = df.corr()
        else:
            raise ValueError(f'Unrecognized measure: {measure}')

        sns.heatmap(
            res,
            cmap=cmap,
            vmin=vmin,
            vmax=vmax,
            ax=ax,
            **kwargs
        )

    @property
    def df(self):
        """Posterior samples as a data frame, one column per parameter.

        Raises ``RuntimeError`` if the journal has not been written yet.
        """
        self._check_journal_status()
        return self._df

    def save(self, filename):
        """Store the journal to disk.

        Function from ABCpy source code.

        Parameters
        ----------
        filename : string
            The location of the file to store the current object to.
        """
        with open(filename, 'wb') as output:
            # -1 in the original call selects the highest pickle protocol
            pickle.dump(self, output, pickle.HIGHEST_PROTOCOL)

    @classmethod
    def load(cls, filename):
        """Read a saved journal from disk and return it as an object.

        Function from ABCpy source code.

        Notes
        -----
        To store a journal use Journal.save(filename).

        Parameters
        ----------
        filename : string
            The string representing the location of a file.

        Returns
        -------
        Journal
            The journal object serialized in <filename>.

        Example
        --------
        >>> jnl = Journal.load('example_output.jnl')
        """
        # NOTE(security): unpickling can execute arbitrary code; only load
        # journal files that come from a trusted source.
        with open(filename, 'rb') as infile:
            return pickle.load(infile)
if __name__ == "__main__":
    import pylfi

    # Demo: load a stored journal and show the marginal posterior of one
    # parameter together with its true value.
    journal_path = 'hh_rej_normal_best_posterior_org.jnl'
    jnl = pylfi.Journal.load(journal_path)

    # The freshly created axes are the current axes, so passing them
    # explicitly draws on the same axes ``plt.gca()`` would return.
    figure, axes = plt.subplots(figsize=(6, 4), tight_layout=True)

    # Other things to try:
    #jnl.plot_joint('gbarK', 'gbarNa', 36., 120.)
    #jnl.plot_prior("gbarK", np.linspace(30, 40, 1000))
    jnl.plot_posterior(
        "gbarNa",
        point_estimate='map',
        theta_true=120.,
        ax=axes,
    )

    plt.show()