Source code for statsmodels.tools.dump2module

'''Save a set of numpy arrays to a python module file that can be imported

Author : Josef Perktold
'''
from __future__ import print_function
from statsmodels.compat.python import iterkeys
import numpy as np

class HoldIt(object):
    '''Class to write numpy arrays into a python module

    Calling save on the instance of this class writes all attributes of the
    instance into a module file. For details see the save method.
    '''

    def __init__(self, name):
        # name of the holder variable that is created in the saved module
        self.name = name

    def save(self, what=None, filename=None, header=True, useinstance=True,
             comment=None, print_options=None):
        '''write attributes of this instance to python module given by filename

        Parameters
        ----------
        what : list or None
            list of attributes that are added to the module. If None (default)
            then all attributes in __dict__ that do not start with an underline
            will be saved.
        filename : string
            specifies filename with path. If the file does not exist, it will
            be created. If the file already exists, then the new data will be
            appended to the file. If None, then nothing is written to disk.
        header : bool
            If true, then the imports of the module and the class definition
            are written before writing the data.
        useinstance : bool
            If true, then the data in the module are attached to an instance
            of a holder class. If false, then each array will be saved as
            separate variable.
        comment : string
            If comment is not None then this string will be attached as a
            description comment to the data instance in the saved module.
        print_options : dict or None
            The print_options for the numpy arrays will be updated with this.
            see notes

        Returns
        -------
        txt : list of strings
            the pieces of the module source; they are joined with newlines
            when written to the file.

        Notes
        -----
        The content of a numpy array is written using repr, which can be
        controlled with np.set_printoptions. The numpy default is updated
        with: precision=20, linewidth=100, nanstr='nan', infstr='inf'

        This should provide enough precision for double floating point
        numbers. If one array has more than 1000 elements, then threshold
        should be overwritten by the user, see keyword argument print_options.
        The previous numpy print options are restored before returning, also
        if creating the representation of an attribute fails.
        '''
        print_opt_old = np.get_printoptions()
        print_opt = dict(precision=20, linewidth=100, nanstr='nan',
                         infstr='inf')
        if print_options:
            print_opt.update(print_options)

        # Materialize the names: verify() iterates over them again later, so
        # a one-shot generator would already be exhausted by then.
        if what is None:
            what = [i for i in self.__dict__ if not i.startswith('_')]
        else:
            what = list(what)

        if header:
            txt = ['import numpy as np\n'
                   'from numpy import array, rec, inf, nan\n\n']
            if useinstance:
                txt.append('class Holder(object):\n    pass\n\n')
        else:
            txt = []

        if useinstance:
            txt.append('%s = Holder()' % self.name)
            prefix = '%s.' % self.name
        else:
            prefix = ''

        if comment is not None:
            # repr keeps the generated source valid when the comment contains
            # quotes, backslashes or line breaks
            txt.append('%scomment = %r' % (prefix, comment))

        # precision corrects for non-scientific notation
        np.set_printoptions(**print_opt)
        try:
            for x in what:
                txt.append('%s%s = %s' % (prefix, x, repr(getattr(self, x))))
        finally:
            np.set_printoptions(**print_opt_old)
        txt.extend(['', ''])  # add empty lines at end

        if filename is not None:
            # append mode: creates the file if necessary, keeps existing data
            with open(filename, 'a') as fh:
                fh.write('\n'.join(txt))

        # remembered for verify
        self._filename = filename
        self._useinstance = useinstance
        self._what = what
        return txt

    def verify(self):
        '''load the saved module and verify the data

        This tries several ways of comparing the saved and the attached data,
        but might not work for all possible data structures.

        Returns
        -------
        all_correct : bool
            true if no differences are found, for floating point numbers
            rtol=1e-16, atol=1e-16 is used to determine equality (allclose)
        correctli : list
            list of attribute names that compare as equal
        incorrectli : list
            list of attribute names that did not compare as equal, either
            because they differ or because the comparison does not handle
            the data structure correctly

        Raises
        ------
        ValueError
            if save has not been called with a filename before.
        NotImplementedError
            if the data was saved with useinstance=False.
        '''
        filename = getattr(self, '_filename', None)
        if filename is None:
            raise ValueError('save has to be called with a filename before '
                             'the saved data can be verified')
        if not self._useinstance:
            raise NotImplementedError('currently only implemented when '
                                      'useinstance is true')

        # Execute the saved file directly from its path. This also works for
        # files outside of sys.path and always reads the current content.
        import runpy
        namespace = runpy.run_path(filename)
        data = namespace[self.name]

        correctli = []
        incorrectli = []
        for d in self._what:
            self_item = getattr(self, d)
            saved_item = getattr(data, d)
            if self._items_match(d, self_item, saved_item):
                correctli.append(d)
            else:
                incorrectli.append(d)

        return len(incorrectli) == 0, correctli, incorrectli

    @staticmethod
    def _exactly_equal(first, second):
        '''Return True if both items have the same shape and equal elements'''
        try:
            # the shape check avoids that broadcasting hides a difference
            if np.shape(first) != np.shape(second):
                return False
            return bool(np.all(first == second))
        except (ValueError, TypeError):
            # the comparison is not defined for this kind of data
            return False

    @staticmethod
    def _items_match(name, self_item, saved_item):
        '''Return True if the attached and the reloaded item compare as equal

        A RuntimeWarning is issued if non-object data is not equal within
        rtol=1e-16, atol=1e-16.
        '''
        # try simple equality
        if HoldIt._exactly_equal(self_item, saved_item):
            return True

        try:
            self_arr = np.asarray(self_item)
            saved_arr = np.asarray(saved_item)
        except (ValueError, TypeError):
            return False

        if self_arr.dtype != np.dtype('object'):
            # try allclose
            try:
                close = (self_arr.shape == saved_arr.shape and
                         bool(np.allclose(self_arr, saved_arr,
                                          rtol=1e-16, atol=1e-16)))
            except (ValueError, TypeError):
                # e.g. string data, allclose only handles numbers
                close = False
            if not close:
                import warnings
                warnings.warn("inexact precision in " + name, RuntimeWarning)
            return close

        # try iterating, if the object array wraps a single dictionary
        if self_arr.ndim == 0 and saved_arr.ndim == 0:
            self_dict = self_arr.item()
            saved_dict = saved_arr.item()
            if isinstance(self_dict, dict) and isinstance(saved_dict, dict):
                return (set(self_dict) == set(saved_dict) and
                        all(HoldIt._exactly_equal(self_dict[k], saved_dict[k])
                            for k in self_dict))
        return False
if __name__ == '__main__':
    # Developer example: convert the arrays of an npz archive into the python
    # module testsave.py and check that the saved data can be read back.

    # NOTE(review): the archive presumably contains pickled object arrays
    # (see the object dtype branch below), which numpy only loads with
    # allow_pickle=True. Only use this with trusted, locally created files.
    data = np.load(r"E:\Josef\eclipsegworkspace\statsmodels-josef-experimental-030\dist\statsmodels-0.3.0dev_with_Winhelp_a2\statsmodels-0.3.0dev\scikits\statsmodels\tsa\vector_ar\tests\results\vars_results.npz",
                   allow_pickle=True)
    res_var = HoldIt('var_results')
    for d in data:
        setattr(res_var, d, data[d])
    np.set_printoptions(precision=120, linewidth=100)
    res_var.save(filename='testsave.py', header=True,
                 comment='VAR test data converted from vars_results.npz')

    # the module has just been written to the current directory
    import testsave

    for d in data:
        print(d)
        original = data[d]
        saved = getattr(testsave.var_results, d)
        is_object = original.dtype == np.dtype('object')
        # try simple equality
        correct = np.all(original == saved)
        # try allclose
        if not correct and not is_object:
            correct = np.allclose(original, saved, rtol=1e-16, atol=1e-16)
            if not correct:
                print("inexact precision")
        # try iterating over the entries of the wrapped dict, if object array;
        # item() is not available for plain arrays with several elements
        if not correct and is_object:
            correlem = [np.all(original.item()[k] == saved.item()[k])
                        for k in original.item()]
            # every entry has to match, a non-empty list is always true
            if not all(correlem):
                print(d, "wrong")

    print(res_var.verify())