Source code for pycmdstan.io

"""
I/O functions for working with CmdStan executables.

"""

import os
import re
import threading
import numpy as np


def _rdump_array(key, val):
    c = 'c(' + ', '.join(map(str, val.T.flat)) + ')'
    if (val.size, ) == val.shape:
        return '{key} <- {c}'.format(key=key, c=c)
    else:
        dim = '.Dim = c{0}'.format(val.shape)
        struct = '{key} <- structure({c}, {dim})'.format(key=key, c=c, dim=dim)
        return struct


[docs]def rdump(fname, data): """Dump a dict of data to a R dump format file. """ with open(fname, 'w') as fd: for key, val in data.items(): if isinstance(val, np.ndarray) and val.size > 1: line = _rdump_array(key, val) else: try: val = val.flat[0] except: pass line = '%s <- %s' % (key, val) fd.write(line) fd.write('\n')
[docs]def rload(fname): """Load a dict of data from an R dump format file. """ with open(fname, 'r') as fd: lines = fd.readlines() data = {} for line in lines: lhs, rhs = [_.strip() for _ in line.split('<-')] if rhs.startswith('structure'): *_, vals, dim = rhs.replace('(', ' ').replace(')', ' ').split('c') vals = [float(v) for v in vals.split(',')[:-1]] dim = [int(v) for v in dim.split(',')] val = np.array(vals).reshape(dim[::-1]).T elif rhs.startswith('c'): val = np.array([float(_) for _ in rhs[2:-1].split(',')]) else: try: val = int(rhs) except: try: val = float(rhs) except: raise ValueError(rhs) data[lhs] = val return data
[docs]def merge_csv_data(*csvs, skip=0): """Merge multiple CSV dicts into a single dict. """ data_ = {} for csv in csvs: for key, val in csv.items(): # XXX do better if key in 'loo loos ks'.split(): continue val = val[skip:] if key in data_: data_[key] = np.concatenate((data_[key], val), axis=0) else: data_[key] = val return data_
[docs]def parse_csv(fname, merge=True): """Parse samples from a Stan output CSV file. """ if '*' in fname: import glob return parse_csv(glob.glob(fname), merge=merge) if isinstance(fname, (list, tuple)): csv = [] for _ in fname: try: csv.append(parse_csv(_)) except Exception as e: print('skipping ', fname, e) if merge: csv = merge_csv_data(*csv) return csv lines = [] with open(fname, 'r') as fd: for line in fd.readlines(): if not line.startswith('#'): lines.append(line.strip().split(',')) names = [field.split('.') for field in lines[0]] data = np.array([[float(f) for f in line] for line in lines[1:]]) namemap = {} maxdims = {} for i, name in enumerate(names): if name[0] not in namemap: namemap[name[0]] = [] namemap[name[0]].append(i) if len(name) > 1: maxdims[name[0]] = name[1:] for name in maxdims.keys(): dims = [] for dim in maxdims[name]: dims.append(int(dim)) maxdims[name] = tuple(reversed(dims)) # data in linear order per Stan, e.g. mat is col maj # TODO array is row maj, how to distinguish matrix v array[,]? data_ = {} for name, idx in namemap.items(): new_shape = (-1, ) + maxdims.get(name, ()) data_[name] = data[:, idx].reshape(new_shape) return data_
[docs]def parse_summary_csv(fname): """Parse CSV output of the stansummary program. """ skeys = [] svals = [] niter = -1 with open(fname, 'r') as fd: scols = fd.readline().strip().split(',') for line in fd.readlines(): if 'iterations' in line: niter_match = re.search(r'(\d+) iterations saved', line) if niter_match: niter = int(niter_match.group(1)) continue if line.startswith('#') or '"' not in line: continue _, k, v = line.split('"') skeys.append(k) svals.append(np.array([float(_) for _ in v.split(',')[1:]])) svals = np.array(svals) sdat = {} sdims = {} for skey, sval in zip(skeys, svals): if '[' in skey: name, dim = skey.replace('[', ']').split(']')[:-1] dim = tuple(int(i) for i in dim.split(',')) sdims[name] = dim if name not in sdat: sdat[name] = [] sdat[name].append(sval) else: sdat[skey] = sval for key in [_ for _ in sdat.keys()]: if key in sdims: sdat[key] = np.array(sdat[key]).reshape(sdims[key] + (-1, )) recs = {} dt = [(k, 'f8') for k in scols[1:]] for key, val in sdat.items(): recs[key] = np.rec.array(val, dtype=dt) return niter, recs
# class OnlineCSVParser: # # TODO following lines + col labels is sufficient to alloc arrays AOT # # num_samples = 1000 (Default) # # num_warmup = 1000 (Default) # # save_warmup = 0 (Default) # def __init__(self, csv_fname): # self.csv_fname = csv_fname # self.thread = Thread(target=self._run) # self._line = '' # self.read = True # self.thread.start() # def _run(self): # while True: # try: # self._follow() # except Exception as exc: # print(exc) # def _follow(self): # with open(self.csv_fname, 'r') as fd: # while True: # line = fd.readline() # if line: # self.parse_line(line) # else: # # TODO adapt to avg time btw samples # time.sleep(0.01) # def parse_line(self, line): # pass