98 lines
4.1 KiB
Python
98 lines
4.1 KiB
Python
#
|
|
# Tests for the lambertw function,
|
|
# Adapted from the MPMath tests [1] by Yosef Meller, mellerf@netvision.net.il
|
|
# Distributed under the same license as SciPy itself.
|
|
#
|
|
# [1] mpmath source code, Subversion revision 992
|
|
# http://code.google.com/p/mpmath/source/browse/trunk/mpmath/tests/test_functions2.py?spec=svn994&r=992
|
|
|
|
import numpy as np
|
|
from numpy.testing import assert_, assert_equal, assert_array_almost_equal
|
|
from scipy.special import lambertw
|
|
from numpy import nan, inf, pi, e, isnan, log, r_, array, complex_
|
|
|
|
from scipy.special._testutils import FuncData
|
|
|
|
|
|
def test_values():
|
|
assert_(isnan(lambertw(nan)))
|
|
assert_equal(lambertw(inf,1).real, inf)
|
|
assert_equal(lambertw(inf,1).imag, 2*pi)
|
|
assert_equal(lambertw(-inf,1).real, inf)
|
|
assert_equal(lambertw(-inf,1).imag, 3*pi)
|
|
|
|
assert_equal(lambertw(1.), lambertw(1., 0))
|
|
|
|
data = [
|
|
(0,0, 0),
|
|
(0+0j,0, 0),
|
|
(inf,0, inf),
|
|
(0,-1, -inf),
|
|
(0,1, -inf),
|
|
(0,3, -inf),
|
|
(e,0, 1),
|
|
(1,0, 0.567143290409783873),
|
|
(-pi/2,0, 1j*pi/2),
|
|
(-log(2)/2,0, -log(2)),
|
|
(0.25,0, 0.203888354702240164),
|
|
(-0.25,0, -0.357402956181388903),
|
|
(-1./10000,0, -0.000100010001500266719),
|
|
(-0.25,-1, -2.15329236411034965),
|
|
(0.25,-1, -3.00899800997004620-4.07652978899159763j),
|
|
(-0.25,-1, -2.15329236411034965),
|
|
(0.25,1, -3.00899800997004620+4.07652978899159763j),
|
|
(-0.25,1, -3.48973228422959210+7.41405453009603664j),
|
|
(-4,0, 0.67881197132094523+1.91195078174339937j),
|
|
(-4,1, -0.66743107129800988+7.76827456802783084j),
|
|
(-4,-1, 0.67881197132094523-1.91195078174339937j),
|
|
(1000,0, 5.24960285240159623),
|
|
(1000,1, 4.91492239981054535+5.44652615979447070j),
|
|
(1000,-1, 4.91492239981054535-5.44652615979447070j),
|
|
(1000,5, 3.5010625305312892+29.9614548941181328j),
|
|
(3+4j,0, 1.281561806123775878+0.533095222020971071j),
|
|
(-0.4+0.4j,0, -0.10396515323290657+0.61899273315171632j),
|
|
(3+4j,1, -0.11691092896595324+5.61888039871282334j),
|
|
(3+4j,-1, 0.25856740686699742-3.85211668616143559j),
|
|
(-0.5,-1, -0.794023632344689368-0.770111750510379110j),
|
|
(-1./10000,1, -11.82350837248724344+6.80546081842002101j),
|
|
(-1./10000,-1, -11.6671145325663544),
|
|
(-1./10000,-2, -11.82350837248724344-6.80546081842002101j),
|
|
(-1./100000,4, -14.9186890769540539+26.1856750178782046j),
|
|
(-1./100000,5, -15.0931437726379218666+32.5525721210262290086j),
|
|
((2+1j)/10,0, 0.173704503762911669+0.071781336752835511j),
|
|
((2+1j)/10,1, -3.21746028349820063+4.56175438896292539j),
|
|
((2+1j)/10,-1, -3.03781405002993088-3.53946629633505737j),
|
|
((2+1j)/10,4, -4.6878509692773249+23.8313630697683291j),
|
|
(-(2+1j)/10,0, -0.226933772515757933-0.164986470020154580j),
|
|
(-(2+1j)/10,1, -2.43569517046110001+0.76974067544756289j),
|
|
(-(2+1j)/10,-1, -3.54858738151989450-6.91627921869943589j),
|
|
(-(2+1j)/10,4, -4.5500846928118151+20.6672982215434637j),
|
|
(pi,0, 1.073658194796149172092178407024821347547745350410314531),
|
|
|
|
# Former bug in generated branch,
|
|
(-0.5+0.002j,0, -0.78917138132659918344 + 0.76743539379990327749j),
|
|
(-0.5-0.002j,0, -0.78917138132659918344 - 0.76743539379990327749j),
|
|
(-0.448+0.4j,0, -0.11855133765652382241 + 0.66570534313583423116j),
|
|
(-0.448-0.4j,0, -0.11855133765652382241 - 0.66570534313583423116j),
|
|
]
|
|
data = array(data, dtype=complex_)
|
|
|
|
def w(x, y):
|
|
return lambertw(x, y.real.astype(int))
|
|
with np.errstate(all='ignore'):
|
|
FuncData(w, data, (0,1), 2, rtol=1e-10, atol=1e-13).check()
|
|
|
|
|
|
def test_ufunc():
|
|
assert_array_almost_equal(
|
|
lambertw(r_[0., e, 1.]), r_[0., 1., 0.567143290409783873])
|
|
|
|
|
|
def test_lambertw_ufunc_loop_selection():
|
|
# see https://github.com/scipy/scipy/issues/4895
|
|
dt = np.dtype(np.complex128)
|
|
assert_equal(lambertw(0, 0, 0).dtype, dt)
|
|
assert_equal(lambertw([0], 0, 0).dtype, dt)
|
|
assert_equal(lambertw(0, [0], 0).dtype, dt)
|
|
assert_equal(lambertw(0, 0, [0]).dtype, dt)
|
|
assert_equal(lambertw([0], [0], [0]).dtype, dt)
|