from .io import parse_csv
from math import sqrt
import numpy as np
import pylab as pl
def _plot_key(fn, csv, *keys):
alpha = 1.0 / sqrt(len(keys))
for key in keys:
fn(csv[key], label=key, alpha=alpha)
pl.legend()
def plot_key(csv, *keys):
    """Draw a line (trace) plot of each named parameter in the sample set.

    Parameters
    ----------
    csv : mapping
        Sample set indexed by parameter name.
    *keys
        Parameter names to plot. They are drawn together, each labelled
        with its name, followed by a legend.
    """
    trace_fn = pl.plot
    _plot_key(trace_fn, csv, *keys)
def hist_key(csv, *keys):
    """Create a histogram for keys (parameter names) in sample set.

    Each key's values are flattened to 1-D, so a multi-dimensional
    parameter is pooled into a single histogram. The bin count follows the
    square-root rule on the number of pooled values.

    Parameters
    ----------
    csv : mapping
        Sample set indexed by parameter name.
    *keys
        Parameter names to histogram, overlaid with a legend.
    """
    def _hist(x, **kwargs):
        flat = np.asarray(x).reshape(-1)
        # Base the square-root rule on the number of values actually
        # histogrammed, not just the leading dimension. Keep at least one
        # bin so tiny or empty inputs never request bins=0.
        n_bins = max(1, int(sqrt(flat.size)))
        pl.hist(flat, bins=n_bins, **kwargs)
    _plot_key(_hist, csv, *keys)
def trace_nuts(csv, extras='', skip=0, n_col=4):
    """Trace plots of NUTS state, along with other model parameters.

    Every key ending in ``__`` (sampler state columns such as
    ``stepsize__``) is plotted, plus any key named in ``extras``. Each
    selected key gets its own subplot. When several chains are given, they
    are overlaid on the same subplots, in the order their keys appear.

    Parameters
    ----------
    csv : dict or list of dict
        Sample set, or a list of per-chain sample sets.
    extras : str or iterable of str or None
        Additional keys to plot. A string is split on whitespace.
    skip : int
        Number of leading draws omitted from each trace.
    n_col : int
        Number of subplot columns.
    """
    from pylab import subplot, plot, gca, title, grid, xticks
    if isinstance(csv, dict):
        csv = [csv]
    if extras is None:
        extras = []
    elif isinstance(extras, str):
        extras = extras.split()
    extras = set(extras)
    chosen = [[key for key in chain.keys()
               if key.endswith('__') or key in extras]
              for chain in csv]
    # Size the grid from the keys actually present instead of assuming a
    # fixed number of sampler columns. This keeps every key inside the
    # grid and avoids allocating empty rows.
    n_subplots = max([len(chain_keys) for chain_keys in chosen] or [0])
    if n_subplots == 0:
        return
    n_row = -(-n_subplots // n_col)  # ceiling division
    for chain, chain_keys in zip(csv, chosen):
        for i, key in enumerate(chain_keys, 1):
            subplot(n_row, n_col, i)
            plot(chain[key][skip:], alpha=0.5)
            if key == 'stepsize__':
                gca().set_yscale('log')
            title(key)
            grid(1)
            # Hide x tick labels only where another subplot sits directly
            # below. E.g. in a 3x3 grid with 7 used, subplots 5, 6 and 7
            # keep their x ticks.
            if i + n_col <= n_subplots:
                xticks(xticks()[0], [])
def pairs(csv, keys, skip=0, bins=20):
    """Create a pairs plot for keys in the given dataset.

    Draws an ``n x n`` grid of subplots, where ``n = len(keys)``.

    - Diagonal cell ``(i, i)``: a log-scaled histogram of ``keys[i]``.
    - Off-diagonal cell ``(i, j)``: ``keys[j]`` (x) plotted against
      ``keys[i]`` (y) as points.

    With several chains, every chain is drawn into each cell. Column titles
    and row labels are the key names.

    Parameters
    ----------
    csv : dict or list of dict
        Sample set, or a list of per-chain sample sets.
    keys : str or sequence of str
        Keys to plot. A string is split on whitespace.
    skip : int
        Number of leading draws omitted from each series.
    bins : int or sequence
        Histogram bins for the diagonal cells (default 20).
    """
    if isinstance(keys, str):
        # A bare string would otherwise be iterated character by character.
        keys = keys.split()
    if isinstance(csv, dict):
        csv = [csv]  # following assumes list of chains' results
    n = len(keys)
    for i, key_i in enumerate(keys):
        for j, key_j in enumerate(keys):
            pl.subplot(n, n, i * n + j + 1)
            for chain in csv:
                y = chain[key_i][skip:]
                if i == j:
                    pl.hist(y, bins, log=True)
                else:
                    pl.plot(chain[key_j][skip:], y, '.')
            if i == 0:
                pl.title(key_j)
            if j == 0:
                pl.ylabel(key_i)
def parallel_coordinates(csv, keys, marker='ko-'):
    """Create a parallel coordinates plot for keys in the given dataset.

    How the plot is built:

    - Each selected key is reshaped to ``(nsamp, -1)``, giving one column
      per element.
    - Each column is min-max scaled to [0, 1].
    - Each draw is drawn as one line across all columns.
    - X ticks mark the first column of each key, labelled with its name.

    Parameters
    ----------
    csv : dict
        Sample set mapping key -> array whose first axis indexes draws.
    keys : str or collection of str
        Keys to include. A string is split on whitespace. Keys appear in
        the order they occur in ``csv``; keys absent from ``csv`` are
        ignored.
    marker : str
        Matplotlib format string for the lines (default ``'ko-'``).

    Raises
    ------
    ValueError
        If none of ``keys`` is present in ``csv``.
    """
    if isinstance(keys, str):
        # Avoid substring matching (`k in 'theta'` matches 'eta').
        keys = keys.split()
    wanted = set(keys)
    flats = {k: np.asarray(v) for k, v in csv.items() if k in wanted}
    if not flats:
        raise ValueError('none of the requested keys are in the sample set')
    if 'lp__' in csv:
        nsamp = np.shape(csv['lp__'])[0]
    else:
        nsamp = next(iter(flats.values())).shape[0]
    columns = []
    key_idx = []
    key_val = []
    offset = 0
    for k, v in flats.items():
        flat = v.reshape((nsamp, -1))
        columns.append(flat)
        key_idx.append(offset)
        key_val.append(k)
        offset += flat.shape[1]
    mats = np.hstack(columns)
    lo = mats.min(axis=0)
    # np.ptp instead of ndarray.ptp, which was removed in NumPy 2.0.
    # Constant columns get a unit range so they map to 0 instead of NaN.
    span = np.ptp(mats, axis=0)
    span = np.where(span > 0, span, 1.0)
    # Transpose to (columns, draws) so plot() draws one line per draw.
    mats = ((mats - lo) / span).T
    pl.plot(mats, marker, alpha=1 / np.sqrt(nsamp))
    pl.xticks(key_idx, key_val)
    pl.yticks([])
    pl.grid(1)