import numpy as np
import numpy.testing as npt
import pytest
from pytest import raises as assert_raises
from scipy.integrate import IntegrationWarning

from scipy import stats
from scipy.special import betainc
from .common_tests import (check_normalization, check_moment,
                           check_mean_expect,
                           check_var_expect, check_skew_expect,
                           check_kurt_expect, check_entropy,
                           check_private_entropy, check_entropy_vect_scale,
                           check_edge_support, check_named_args,
                           check_random_state_property,
                           check_meth_dtype, check_ppf_dtype,
                           check_cmplx_deriv,
                           check_pickling, check_rvs_broadcast,
                           check_freezing)
from scipy.stats._distr_params import distcont

"""
Test all continuous distributions.

Parameters were chosen for those distributions that pass the
Kolmogorov-Smirnov test.  This provides safe parameters for each
distribution so that we can perform further testing of class methods.

These tests currently check only/mostly for serious errors and exceptions,
not for numerically exact results.
"""

# Note that you need to add new distributions you want tested
# to _distr_params

# Number of decimal places passed as ``decimal=`` to the
# npt.assert_almost_equal-based checks (check_cdf_ppf, check_sf_isf,
# check_pdf).  Increased from 0 to 5.
DECIMAL = 5  # specify the precision of the tests  # increased from 0 to 5

# Last three of these fail all around. Need to be checked
# NOTE(review): distcont_extra does not appear to be referenced by any test
# in this module -- confirm before relying on it or removing it.
distcont_extra = [
    ['betaprime', (100, 86)],
    ['fatiguelife', (5,)],
    ['invweibull', (0.58847112119264788,)],
    # burr: sample mean test fails still for c<1
    ['burr', (0.94839838075366045, 4.3820284068855795)],
    # genextreme: sample mean test, sf-logsf test fail
    ['genextreme', (3.3184017469423535,)],
]


# Distributions whose test_cont_basic / test_rvs_scalar cases are wrapped
# with pytest.mark.slow by cases_test_cont_basic.
distslow = ['kstwo', 'ksone', 'kappa4', 'gausshyper', 'recipinvgauss',
            'genexpon', 'vonmises', 'vonmises_line', 'cosine', 'invweibull',
            'powerlognorm', 'johnsonsu', 'kstwobign']
# distslow are sorted by speed (very slow to slow)

# skip check_fit_args (test is slow)
skip_fit_test = ['exponpow', 'exponweib', 'gausshyper', 'genexpon',
                 'halfgennorm', 'gompertz', 'johnsonsb', 'johnsonsu',
                 'kappa4', 'ksone', 'kstwo', 'kstwobign', 'mielke', 'ncf',
                 'nct', 'powerlognorm', 'powernorm', 'recipinvgauss',
                 'trapz', 'vonmises', 'vonmises_line', 'levy_stable',
                 'rv_histogram_instance']

# skip check_fit_args_fix (test is slow)
skip_fit_fix_test = ['burr', 'exponpow', 'exponweib',
                     'gausshyper', 'genexpon', 'halfgennorm',
                     'gompertz', 'johnsonsb', 'johnsonsu', 'kappa4',
                     'ksone', 'kstwo', 'kstwobign', 'levy_stable', 'mielke',
                     'ncf', 'ncx2', 'powerlognorm', 'powernorm', 'rdist',
                     'recipinvgauss', 'trapz', 'vonmises', 'vonmises_line']

# These distributions fail the complex derivative test below.
# Here 'fail' means producing wrong results and/or raising exceptions,
# depending on the implementation details of the corresponding special
# functions.
# cf https://github.com/scipy/scipy/pull/4979 for a discussion.
# Distributions for which test_cont_basic skips check_cmplx_deriv.
fails_cmplx = {'beta', 'betaprime', 'chi', 'chi2', 'dgamma', 'dweibull',
               'erlang', 'f', 'gamma', 'gausshyper', 'gengamma',
               'geninvgauss', 'gennorm', 'genpareto', 'halfgennorm',
               'invgamma', 'ksone', 'kstwo', 'kstwobign', 'levy_l',
               'loggamma', 'logistic', 'loguniform', 'maxwell', 'nakagami',
               'ncf', 'nct', 'ncx2', 'norminvgauss', 'pearson3', 'rdist',
               'reciprocal', 'rice', 'skewnorm', 't', 'tukeylambda',
               'vonmises', 'vonmises_line', 'rv_histogram_instance'}

# An rv_histogram built from a small fixed sample.  It is appended to the
# distcont cases so the generic checks also run on a distribution object
# that is not looked up by name on scipy.stats.
_h = np.histogram([1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6,
                   6, 7, 7, 7, 8, 8, 9], bins=8)
histogram_test_instance = stats.rv_histogram(_h)


def _resolve_dist(distname):
    """Map a parametrized case to ``(distfn, distname)``.

    A string is looked up on ``scipy.stats`` and returned unchanged as the
    name.  A non-string (the ``histogram_test_instance`` case) makes
    ``getattr`` raise ``TypeError``; the object itself is then the
    distribution and ``'rv_histogram_instance'`` is used as its label.
    """
    try:
        return getattr(stats, distname), distname
    except TypeError:
        return distname, 'rv_histogram_instance'


def cases_test_cont_basic():
    """Yield ``(distname, arg)`` cases for test_cont_basic/test_rvs_scalar.

    ``levy_stable`` is left out (it only implements rvs(); see
    test_levy_stable_random_state_property).  Names in ``distslow`` are
    wrapped with ``pytest.mark.slow``.  The histogram instance is appended
    with an empty shape-argument tuple.
    """
    for distname, arg in distcont[:] + [(histogram_test_instance, tuple())]:
        if distname == 'levy_stable':
            continue
        elif distname in distslow:
            yield pytest.param(distname, arg, marks=pytest.mark.slow)
        else:
            yield distname, arg


@pytest.mark.parametrize('distname,arg', cases_test_cont_basic())
def test_cont_basic(distname, arg):
    # Slow distributions are not skipped here; cases_test_cont_basic marks
    # them with pytest.mark.slow so they can be deselected.
    if distname == 'truncnorm':
        pytest.xfail(reason=distname)

    distfn, distname = _resolve_dist(distname)
    np.random.seed(765456)
    sn = 500
    with npt.suppress_warnings() as sup:
        # frechet_l and frechet_r are deprecated, so all their
        # methods generate DeprecationWarnings.
        sup.filter(category=DeprecationWarning, message=".*frechet_")
        rvs = distfn.rvs(size=sn, *arg)
        sm = rvs.mean()
        sv = rvs.var()
        m, v = distfn.stats(*arg)

        check_sample_meanvar_(distfn, arg, m, v, sm, sv, sn,
                              distname + 'sample mean test')
        check_cdf_ppf(distfn, arg, distname)
        check_sf_isf(distfn, arg, distname)
        check_pdf(distfn, arg, distname)
        check_pdf_logpdf(distfn, arg, distname)
        check_pdf_logpdf_at_endpoints(distfn, arg, distname)
        check_cdf_logcdf(distfn, arg, distname)
        check_sf_logsf(distfn, arg, distname)
        check_ppf_broadcast(distfn, arg, distname)

        alpha = 0.01
        if distname == 'rv_histogram_instance':
            check_distribution_rvs(distfn.cdf, arg, alpha, rvs)
        elif distname != 'geninvgauss':
            # skip kstest for geninvgauss since cdf is too slow; see test for
            # rv generation in TestGenInvGauss in test_distributions.py
            check_distribution_rvs(distname, arg, alpha, rvs)

        locscale_defaults = (0, 1)
        meths = [distfn.pdf, distfn.logpdf, distfn.cdf, distfn.logcdf,
                 distfn.logsf]
        # make sure arguments are within support
        spec_x = {'frechet_l': -0.5, 'weibull_max': -0.5, 'levy_l': -0.5,
                  'pareto': 1.5, 'tukeylambda': 0.3,
                  'rv_histogram_instance': 5.0}
        x = spec_x.get(distname, 0.5)
        # NOTE(review): these overrides also apply to every check below,
        # including the fit checks on ``rvs`` drawn with the original arg.
        if distname == 'invweibull':
            arg = (1,)
        elif distname == 'ksone':
            arg = (3,)
        check_named_args(distfn, x, arg, locscale_defaults, meths)
        check_random_state_property(distfn, arg)
        check_pickling(distfn, arg)
        check_freezing(distfn, arg)

        # Entropy
        if distname not in ['kstwobign', 'kstwo']:
            check_entropy(distfn, arg, distname)

        if distfn.numargs == 0:
            check_vecentropy(distfn, arg)

        if (distfn.__class__._entropy != stats.rv_continuous._entropy
                and distname != 'vonmises'):
            check_private_entropy(distfn, arg, stats.rv_continuous)

        with npt.suppress_warnings() as sup:
            sup.filter(IntegrationWarning, "The occurrence of roundoff error")
            sup.filter(IntegrationWarning, "Extremely bad integrand")
            sup.filter(RuntimeWarning, "invalid value")
            check_entropy_vect_scale(distfn, arg)

        check_retrieving_support(distfn, arg)
        check_edge_support(distfn, arg)

        check_meth_dtype(distfn, arg, meths)
        check_ppf_dtype(distfn, arg)

        if distname not in fails_cmplx:
            check_cmplx_deriv(distfn, arg)

        # truncnorm already xfailed at the top, so this guard is always
        # true here; kept as documentation of the known truncnorm failure.
        if distname != 'truncnorm':
            check_ppf_private(distfn, arg, distname)

        if distname not in skip_fit_test:
            check_fit_args(distfn, arg, rvs[0:200])

        if distname not in skip_fit_fix_test:
            check_fit_args_fix(distfn, arg, rvs[0:200])


@pytest.mark.parametrize('distname,arg', cases_test_cont_basic())
def test_rvs_scalar(distname, arg):
    # rvs should return a scalar when given scalar arguments (gh-12428)
    distfn, _ = _resolve_dist(distname)

    with npt.suppress_warnings() as sup:
        sup.filter(category=DeprecationWarning, message=".*frechet_")
        assert np.isscalar(distfn.rvs(*arg))
        assert np.isscalar(distfn.rvs(*arg, size=()))
        assert np.isscalar(distfn.rvs(*arg, size=None))


def test_levy_stable_random_state_property():
    # levy_stable only implements rvs(), so it is skipped in the
    # main loop in test_cont_basic(). Here we apply just the test
    # check_random_state_property to levy_stable.
    check_random_state_property(stats.levy_stable, (0.5, 0.1))


def cases_test_moments():
    """Yield cases for test_moments.

    Each case is ``(distname, arg, normalization_ok, higher_ok,
    is_xfailing)``.  Distributions with known moment problems are yielded
    twice: once with the failing checks disabled, and once with every check
    enabled under ``pytest.mark.xfail``.
    """
    fail_normalization = {'vonmises'}
    fail_higher = {'vonmises', 'ncf'}

    for distname, arg in distcont[:] + [(histogram_test_instance, tuple())]:
        if distname == 'levy_stable':
            continue

        cond1 = distname not in fail_normalization
        cond2 = distname not in fail_higher

        yield distname, arg, cond1, cond2, False

        if not cond1 or not cond2:
            # Run the distributions that have issues twice, once skipping the
            # not_ok parts, once with the not_ok parts but marked as knownfail
            yield pytest.param(distname, arg, True, True, True,
                               marks=pytest.mark.xfail)


@pytest.mark.slow
@pytest.mark.parametrize('distname,arg,normalization_ok,higher_ok,is_xfailing',
                         cases_test_moments())
def test_moments(distname, arg, normalization_ok, higher_ok, is_xfailing):
    distfn, distname = _resolve_dist(distname)

    with npt.suppress_warnings() as sup:
        sup.filter(IntegrationWarning,
                   "The integral is probably divergent, or slowly convergent.")
        sup.filter(category=DeprecationWarning, message=".*frechet_")
        if is_xfailing:
            sup.filter(IntegrationWarning)

        m, v, s, k = distfn.stats(*arg, moments='mvsk')

        if normalization_ok:
            check_normalization(distfn, arg, distname)

        if higher_ok:
            check_mean_expect(distfn, arg, m, distname)
            check_skew_expect(distfn, arg, m, v, s, distname)
            check_var_expect(distfn, arg, m, v, distname)
            check_kurt_expect(distfn, arg, m, v, k, distname)

        check_loc_scale(distfn, arg, m, v, distname)
        check_moment(distfn, arg, m, v, distname)


@pytest.mark.parametrize('dist,shape_args', distcont)
def test_rvs_broadcast(dist, shape_args):
    if dist in ['gausshyper', 'genexpon']:
        pytest.skip("too slow")

    # If shape_only is True, it means the _rvs method of the
    # distribution uses more than one random number to generate a random
    # variate.  That means the result of using rvs with broadcasting or
    # with a nontrivial size will not necessarily be the same as using the
    # numpy.vectorize'd version of rvs(), so we can only compare the shapes
    # of the results, not the values.
    # Whether or not a distribution is in the following list is an
    # implementation detail of the distribution, not a requirement.  If
    # the implementation the rvs() method of a distribution changes, this
    # test might also have to be changed.
    shape_only = dist in ['argus', 'betaprime', 'dgamma', 'dweibull',
                          'exponnorm', 'geninvgauss', 'levy_stable', 'nct',
                          'norminvgauss', 'rice', 'skewnorm', 'semicircular']

    distfunc = getattr(stats, dist)
    loc = np.zeros(2)
    scale = np.ones((3, 1))
    nargs = distfunc.numargs
    allargs = []
    bshape = [3, 2]
    # Generate shape parameter arguments...
    for k in range(nargs):
        shp = (k + 4,) + (1,)*(k + 2)
        allargs.append(shape_args[k]*np.ones(shp))
        bshape.insert(0, k + 4)
    allargs.extend([loc, scale])
    # bshape holds the expected shape when loc, scale, and the shape
    # parameters are all broadcast together.

    check_rvs_broadcast(distfunc, dist, allargs, bshape, shape_only, 'd')


def test_rvs_gh2069_regression():
    # Regression tests for gh-2069.  In scipy 0.17 and earlier,
    # these tests would fail.
    #
    # A typical example of the broken behavior:
    # >>> norm.rvs(loc=np.zeros(5), scale=np.ones(5))
    # array([-2.49613705, -2.49613705, -2.49613705, -2.49613705, -2.49613705])
    np.random.seed(123)
    vals = stats.norm.rvs(loc=np.zeros(5), scale=1)
    d = np.diff(vals)
    npt.assert_(np.all(d != 0), "All the values are equal, but they shouldn't be!")
    vals = stats.norm.rvs(loc=0, scale=np.ones(5))
    d = np.diff(vals)
    npt.assert_(np.all(d != 0), "All the values are equal, but they shouldn't be!")
    vals = stats.norm.rvs(loc=np.zeros(5), scale=np.ones(5))
    d = np.diff(vals)
    npt.assert_(np.all(d != 0), "All the values are equal, but they shouldn't be!")
    vals = stats.norm.rvs(loc=np.array([[0], [0]]), scale=np.ones(5))
    d = np.diff(vals.ravel())
    npt.assert_(np.all(d != 0), "All the values are equal, but they shouldn't be!")

    assert_raises(ValueError, stats.norm.rvs, [[0, 0], [0, 0]],
                  [[1, 1], [1, 1]], 1)
    assert_raises(ValueError, stats.gamma.rvs, [2, 3, 4, 5], 0, 1, (2, 2))
    assert_raises(ValueError, stats.gamma.rvs, [1, 1, 1, 1], [0, 0, 0, 0],
                  [[1], [2]], (4,))


def test_nomodify_gh9900_regression():
    # Regression test for gh-9990
    # Prior to gh-9990, calls to stats.truncnorm._cdf() used whatever was
    # set inside the stats.truncnorm instance during stats.truncnorm.cdf().
    # This could cause issues with multi-threaded code.
    # Since then, the calls to cdf() are not permitted to modify the global
    # stats.truncnorm instance.
    tn = stats.truncnorm
    # Use the right-half truncated normal
    # Check that the cdf and _cdf return the same result.
    npt.assert_almost_equal(tn.cdf(1, 0, np.inf), 0.6826894921370859)
    npt.assert_almost_equal(tn._cdf(1, 0, np.inf), 0.6826894921370859)

    # Now use the left-half truncated normal
    npt.assert_almost_equal(tn.cdf(-1, -np.inf, 0), 0.31731050786291415)
    npt.assert_almost_equal(tn._cdf(-1, -np.inf, 0), 0.31731050786291415)

    # Check that the right-half truncated normal _cdf hasn't changed
    npt.assert_almost_equal(tn._cdf(1, 0, np.inf), 0.6826894921370859)  # NOT 1.6826894921370859
    npt.assert_almost_equal(tn.cdf(1, 0, np.inf), 0.6826894921370859)

    # Check that the left-half truncated normal _cdf hasn't changed
    npt.assert_almost_equal(tn._cdf(-1, -np.inf, 0), 0.31731050786291415)  # Not -0.6826894921370859
    npt.assert_almost_equal(tn.cdf(1, -np.inf, 0), 1)  # Not 1.6826894921370859
    npt.assert_almost_equal(tn.cdf(-1, -np.inf, 0), 0.31731050786291415)  # Not -0.6826894921370859


def test_broadcast_gh9990_regression():
    # Regression test for gh-9990
    # The x-value 7 only lies within the support of 4 of the supplied
    # distributions.  Prior to 9990, one array passed to
    # stats.reciprocal._cdf would have 4 elements, but an array
    # previously stored by stats.reciprocal._argcheck() would have 6, leading
    # to a broadcast error.
    a = np.array([1, 2, 3, 4, 5, 6])
    b = np.array([8, 16, 1, 32, 1, 48])
    ans = [stats.reciprocal.cdf(7, _a, _b) for _a, _b in zip(a, b)]
    npt.assert_array_almost_equal(stats.reciprocal.cdf(7, a, b), ans)

    ans = [stats.reciprocal.cdf(1, _a, _b) for _a, _b in zip(a, b)]
    npt.assert_array_almost_equal(stats.reciprocal.cdf(1, a, b), ans)

    ans = [stats.reciprocal.cdf(_a, _a, _b) for _a, _b in zip(a, b)]
    npt.assert_array_almost_equal(stats.reciprocal.cdf(a, a, b), ans)

    ans = [stats.reciprocal.cdf(_b, _a, _b) for _a, _b in zip(a, b)]
    npt.assert_array_almost_equal(stats.reciprocal.cdf(b, a, b), ans)


def test_broadcast_gh7933_regression():
    # Check broadcast works
    stats.truncnorm.logpdf(
        np.array([3.0, 2.0, 1.0]),
        a=(1.5 - np.array([6.0, 5.0, 4.0])) / 3.0,
        b=np.inf,
        loc=np.array([6.0, 5.0, 4.0]),
        scale=3.0
    )


def test_gh2002_regression():
    # Add a check that broadcast works in situations where only some
    # x-values are compatible with some of the shape arguments.
    x = np.r_[-2:2:101j]
    a = np.r_[-np.ones(50), np.ones(51)]
    expected = [stats.truncnorm.pdf(_x, _a, np.inf) for _x, _a in zip(x, a)]
    ans = stats.truncnorm.pdf(x, a, np.inf)
    npt.assert_array_almost_equal(ans, expected)


def test_gh1320_regression():
    # Check that the first example from gh-1320 now works.
    c = 2.62
    stats.genextreme.ppf(0.5, np.array([[c], [c + 0.5]]))
    # The other examples in gh-1320 appear to have stopped working
    # some time ago.
# ans = stats.genextreme.moment(2, np.array([c, c + 0.5]))
# expected = np.array([25.50105963, 115.11191437])
# stats.genextreme.moment(5, np.array([[c], [c + 0.5]]))
# stats.genextreme.moment(5, np.array([c, c + 0.5]))


def check_sample_meanvar_(distfn, arg, m, v, sm, sv, sn, msg):
    """Compare sample mean/variance (sm, sv; sample size sn) with m and v.

    Each comparison runs only when the corresponding population moment is
    finite.  ``distfn``, ``arg`` and ``msg`` are accepted for call-site
    compatibility but are not used.
    """
    # this did not work, skipped silently by nose
    if np.isfinite(m):
        check_sample_mean(sm, sv, sn, m)
    if np.isfinite(v):
        check_sample_var(sv, sn, v)


def check_sample_mean(sm, v, n, popmean):
    """One-sample t-test that sample mean ``sm`` is consistent with popmean.

    ``v`` is the sample variance and ``n`` the sample size.  The two-sided
    p-value is I_{df/(df+t^2)}(df/2, 1/2) (regularized incomplete beta);
    the check fails unless it exceeds 0.01.
    """
    # from stats.stats.ttest_1samp(a, popmean):
    # Calculates the t-obtained for the independent samples T-test on ONE group
    # of scores a, given a population mean.
    #
    # Returns: t-value, two-tailed prob
    df = n-1
    svar = ((n-1)*v) / float(df)    # looks redundant
    t = (sm-popmean) / np.sqrt(svar*(1.0/n))
    prob = betainc(0.5*df, 0.5, df/(df + t*t))

    # return t,prob
    npt.assert_(prob > 0.01, 'mean fail, t,prob = %f, %f, m, sm=%f,%f' %
                (t, prob, popmean, sm))


def check_sample_var(sv, n, popvar):
    # two-sided chisquare test for sample variance equal to
    # hypothesized variance
    #
    # FIXME(review): the statistic below divides ``popvar`` by itself, so
    # for finite nonzero ``popvar`` it is always ``n - 1`` and the check can
    # never fail; ``sv`` is only used in the error message.  The textbook
    # statistic is ``(n - 1)*sv/popvar``, but that test assumes normally
    # distributed data (and ``sf*2`` only rejects the upper tail), so
    # enabling it needs evaluation against heavy-tailed distributions first.
    df = n-1
    chi2 = (n-1)*popvar/float(popvar)
    pval = stats.distributions.chi2.sf(chi2, df) * 2
    npt.assert_(pval > 0.01, 'var fail, t, pval = %f, %f, v, sv=%f, %f' %
                (chi2, pval, popvar, sv))


def check_cdf_ppf(distfn, arg, msg):
    """Check that cdf(ppf(q)) round-trips to q at a few quantiles."""
    values = [0.001, 0.5, 0.999]
    npt.assert_almost_equal(distfn.cdf(distfn.ppf(values, *arg), *arg),
                            values, decimal=DECIMAL, err_msg=msg +
                            ' - cdf-ppf roundtrip')


def check_sf_isf(distfn, arg, msg):
    """Check the sf/isf round trip and that cdf == 1 - sf."""
    npt.assert_almost_equal(distfn.sf(distfn.isf([0.1, 0.5, 0.9], *arg), *arg),
                            [0.1, 0.5, 0.9], decimal=DECIMAL, err_msg=msg +
                            ' - sf-isf roundtrip')
    npt.assert_almost_equal(distfn.cdf([0.1, 0.9], *arg),
                            1.0 - distfn.sf([0.1, 0.9], *arg),
                            decimal=DECIMAL, err_msg=msg +
                            ' - cdf-sf relationship')


def check_pdf(distfn, arg, msg):
    """Compare the pdf at the median with a central difference of the cdf."""
    median = distfn.ppf(0.5, *arg)
    eps = 1e-6
    pdfv = distfn.pdf(median, *arg)
    if (pdfv < 1e-4) or (pdfv > 1e4):
        # avoid checking a case where pdf is close to zero or
        # huge (singularity)
        median = median + 0.1
        pdfv = distfn.pdf(median, *arg)
    cdfdiff = (distfn.cdf(median + eps, *arg) -
               distfn.cdf(median - eps, *arg))/eps/2.0
    # replace with better diff and better test (more points),
    # actually, this works pretty well
    msg += ' - cdf-pdf relationship'
    npt.assert_almost_equal(pdfv, cdfdiff, decimal=DECIMAL, err_msg=msg)


def check_pdf_logpdf(distfn, args, msg):
    """Compare log(pdf) with logpdf at several interior quantiles."""
    points = np.array([0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
    vals = distfn.ppf(points, *args)
    vals = vals[np.isfinite(vals)]
    pdf = distfn.pdf(vals, *args)
    logpdf = distfn.logpdf(vals, *args)
    pdf = pdf[(pdf != 0) & np.isfinite(pdf)]
    logpdf = logpdf[np.isfinite(logpdf)]
    msg += " - logpdf-log(pdf) relationship"
    npt.assert_almost_equal(np.log(pdf), logpdf, decimal=7, err_msg=msg)


def check_pdf_logpdf_at_endpoints(distfn, args, msg):
    """Compare log(pdf) with logpdf at the finite ends of the support.

    ``msg`` (the distribution name from the caller) is used as the prefix of
    the assertion's error message.
    """
    points = np.array([0, 1])
    vals = distfn.ppf(points, *args)
    vals = vals[np.isfinite(vals)]
    with npt.suppress_warnings() as sup:
        # Several distributions incur divide by zero or encounter invalid
        # values when computing the pdf or logpdf at the endpoints.
        suppress_messages = [
            "divide by zero encountered in true_divide",  # multiple distributions
            "divide by zero encountered in log",  # multiple distributions
            "divide by zero encountered in power",  # gengamma
            "invalid value encountered in add",  # genextreme
            "invalid value encountered in subtract",  # gengamma
            "invalid value encountered in multiply"  # recipinvgauss
        ]
        # Use a distinct loop variable: reusing ``msg`` here used to clobber
        # the caller's message, so failures lost the distribution name.
        for message in suppress_messages:
            sup.filter(category=RuntimeWarning, message=message)

        pdf = distfn.pdf(vals, *args)
        logpdf = distfn.logpdf(vals, *args)
        pdf = pdf[(pdf != 0) & np.isfinite(pdf)]
        logpdf = logpdf[np.isfinite(logpdf)]
        msg += " - logpdf-log(pdf) relationship"
        npt.assert_almost_equal(np.log(pdf), logpdf, decimal=7, err_msg=msg)


def check_sf_logsf(distfn, args, msg):
    """Compare log(sf) with logsf at several quantiles, including 0 and 1."""
    points = np.array([0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 1.0])
    vals = distfn.ppf(points, *args)
    vals = vals[np.isfinite(vals)]
    sf = distfn.sf(vals, *args)
    logsf = distfn.logsf(vals, *args)
    sf = sf[sf != 0]
    logsf = logsf[np.isfinite(logsf)]
    msg += " - logsf-log(sf) relationship"
    npt.assert_almost_equal(np.log(sf), logsf, decimal=7, err_msg=msg)


def check_cdf_logcdf(distfn, args, msg):
    """Compare log(cdf) with logcdf at several quantiles, including 0 and 1."""
    points = np.array([0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 1.0])
    vals = distfn.ppf(points, *args)
    vals = vals[np.isfinite(vals)]
    cdf = distfn.cdf(vals, *args)
    logcdf = distfn.logcdf(vals, *args)
    cdf = cdf[cdf != 0]
    logcdf = logcdf[np.isfinite(logcdf)]
    msg += " - logcdf-log(cdf) relationship"
    npt.assert_almost_equal(np.log(cdf), logcdf, decimal=7, err_msg=msg)


def check_ppf_broadcast(distfn, arg, msg):
    """Check that ppf broadcast over repeated shape args matches the scalar."""
    num_repeats = 5
    # Repeat every shape parameter num_repeats times; an empty ``arg`` gives
    # an empty list, i.e. a plain call with no shape arguments.
    args = [np.array([shape] * num_repeats) for shape in arg]

    median = distfn.ppf(0.5, *arg)
    medians = distfn.ppf(0.5, *args)
    msg += " - ppf multiple"
    npt.assert_almost_equal(medians, [median] * num_repeats, decimal=7,
                            err_msg=msg)


def check_distribution_rvs(dist, args, alpha, rvs):
    """Kolmogorov-Smirnov test that ``rvs`` follow ``dist``.

    Parameters
    ----------
    dist : str or callable
        Name of a distribution in scipy.stats, or a cdf function.
    args : tuple
        Shape arguments for the distribution.
    alpha : float
        Significance level (~0.01).
    rvs : array_like
        Previously drawn sample to test first.
    """
    D, pval = stats.kstest(rvs, dist, args=args, N=1000)
    if pval < alpha and isinstance(dist, str):
        # The rvs passed in failed the K-S test, which _could_ happen
        # but is unlikely if alpha is small enough.
        # Repeat the test with a new sample of 1000 rvs generated by kstest
        # from the named distribution.  This is only possible for a name: a
        # cdf callable cannot be used as a sampler, so in that case the
        # assertion below reports the original result instead of crashing.
        D, pval = stats.kstest(dist, dist, args=args, N=1000)
    npt.assert_(pval > alpha, "D = " + str(D) + "; pval = " + str(pval) +
                "; alpha = " + str(alpha) + "\nargs = " + str(args))


def check_vecentropy(distfn, args):
    """Check that vecentropy agrees with the private _entropy."""
    npt.assert_equal(distfn.vecentropy(*args), distfn._entropy(*args))


def check_loc_scale(distfn, arg, m, v, msg):
    """Check that mean/variance transform correctly under loc=10, scale=10."""
    loc, scale = 10.0, 10.0
    mt, vt = distfn.stats(loc=loc, scale=scale, *arg)
    npt.assert_allclose(m*scale + loc, mt)
    npt.assert_allclose(v*scale*scale, vt)


def check_ppf_private(distfn, arg, msg):
    """Check that the private _ppf returns no NaN at 0.1, 0.5 and 0.9."""
    # fails by design for truncnorm self.nb not defined
    ppfs = distfn._ppf(np.array([0.1, 0.5, 0.9]), *arg)
    npt.assert_(not np.any(np.isnan(ppfs)), msg + 'ppf private is nan')


def check_retrieving_support(distfn, args):
    """Check that support() transforms correctly under loc=1, scale=2."""
    loc, scale = 1, 2
    supp = distfn.support(*args)
    supp_loc_scale = distfn.support(*args, loc=loc, scale=scale)
    npt.assert_almost_equal(np.array(supp)*scale + loc,
                            np.array(supp_loc_scale))


def check_fit_args(distfn, arg, rvs):
    """Check that fit() (default and Powell optimizer) returns all params."""
    with np.errstate(all='ignore'), npt.suppress_warnings() as sup:
        sup.filter(category=DeprecationWarning, message=".*frechet_")
        sup.filter(category=RuntimeWarning,
                   message="The shape parameter of the erlang")
        sup.filter(category=RuntimeWarning,
                   message="floating point number truncated")
        vals = distfn.fit(rvs)
        vals2 = distfn.fit(rvs, optimizer='powell')
    # Only check the length of the return
    # FIXME: should check the actual results to see if we are 'close'
    #   to what was created --- but what is 'close' enough
    npt.assert_(len(vals) == 2+len(arg))
    npt.assert_(len(vals2) == 2+len(arg))


def check_fit_args_fix(distfn, arg, rvs):
    """Check that fixed parameters (floc, fscale, f0..f2) are respected."""
    with np.errstate(all='ignore'), npt.suppress_warnings() as sup:
        sup.filter(category=DeprecationWarning, message=".*frechet_")
        sup.filter(category=RuntimeWarning,
                   message="The shape parameter of the erlang")

        vals = distfn.fit(rvs, floc=0)
        vals2 = distfn.fit(rvs, fscale=1)
        npt.assert_(len(vals) == 2+len(arg))
        npt.assert_(vals[-2] == 0)
        npt.assert_(vals2[-1] == 1)
        npt.assert_(len(vals2) == 2+len(arg))
        # Fix each of (at most) the first three shape parameters in turn via
        # f0/f1/f2 and check the fixed value comes back unchanged.
        for idx, fixed_value in enumerate(arg[:3]):
            vals_fixed = distfn.fit(rvs, **{'f%d' % idx: fixed_value})
            npt.assert_(len(vals_fixed) == 2+len(arg))
            npt.assert_(vals_fixed[idx] == fixed_value)


@pytest.mark.parametrize('method', ['pdf', 'logpdf', 'cdf', 'logcdf',
                                    'sf', 'logsf', 'ppf', 'isf'])
@pytest.mark.parametrize('distname, args', distcont)
def test_methods_with_lists(method, distname, args):
    # Test that the continuous distributions can accept Python lists
    # as arguments.
    with npt.suppress_warnings() as sup:
        sup.filter(category=DeprecationWarning, message=".*frechet_")
        dist = getattr(stats, distname)
        f = getattr(dist, method)
        if distname == 'invweibull' and method.startswith('log'):
            x = [1.5, 2]
        else:
            x = [0.1, 0.2]
        shape2 = [[a]*2 for a in args]
        loc = [0, 0.1]
        scale = [1, 1.01]
        result = f(x, *shape2, loc=loc, scale=scale)
        npt.assert_allclose(result,
                            [f(*v) for v in zip(x, *shape2, loc, scale)],
                            rtol=1e-15, atol=1e-15)