Fixed database typo and removed unnecessary class identifier.
This commit is contained in:
parent
00ad49a143
commit
45fb349a7d
5098 changed files with 952558 additions and 85 deletions
229
venv/Lib/site-packages/scipy/io/matlab/tests/test_streams.py
Normal file
229
venv/Lib/site-packages/scipy/io/matlab/tests/test_streams.py
Normal file
|
@ -0,0 +1,229 @@
|
|||
""" Testing
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import zlib
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
from tempfile import mkstemp
|
||||
from contextlib import contextmanager
|
||||
|
||||
import numpy as np
|
||||
|
||||
from numpy.testing import assert_, assert_equal
|
||||
from pytest import raises as assert_raises
|
||||
|
||||
from scipy.io.matlab.streams import (make_stream,
|
||||
GenericStream, ZlibInputStream,
|
||||
_read_into, _read_string, BLOCK_SIZE)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def setup_test_file():
|
||||
val = b'a\x00string'
|
||||
fd, fname = mkstemp()
|
||||
|
||||
with os.fdopen(fd, 'wb') as fs:
|
||||
fs.write(val)
|
||||
with open(fname, 'rb') as fs:
|
||||
gs = BytesIO(val)
|
||||
cs = BytesIO(val)
|
||||
yield fs, gs, cs
|
||||
os.unlink(fname)
|
||||
|
||||
|
||||
def test_make_stream():
|
||||
with setup_test_file() as (fs, gs, cs):
|
||||
# test stream initialization
|
||||
assert_(isinstance(make_stream(gs), GenericStream))
|
||||
|
||||
|
||||
def test_tell_seek():
|
||||
with setup_test_file() as (fs, gs, cs):
|
||||
for s in (fs, gs, cs):
|
||||
st = make_stream(s)
|
||||
res = st.seek(0)
|
||||
assert_equal(res, 0)
|
||||
assert_equal(st.tell(), 0)
|
||||
res = st.seek(5)
|
||||
assert_equal(res, 0)
|
||||
assert_equal(st.tell(), 5)
|
||||
res = st.seek(2, 1)
|
||||
assert_equal(res, 0)
|
||||
assert_equal(st.tell(), 7)
|
||||
res = st.seek(-2, 2)
|
||||
assert_equal(res, 0)
|
||||
assert_equal(st.tell(), 6)
|
||||
|
||||
|
||||
def test_read():
|
||||
with setup_test_file() as (fs, gs, cs):
|
||||
for s in (fs, gs, cs):
|
||||
st = make_stream(s)
|
||||
st.seek(0)
|
||||
res = st.read(-1)
|
||||
assert_equal(res, b'a\x00string')
|
||||
st.seek(0)
|
||||
res = st.read(4)
|
||||
assert_equal(res, b'a\x00st')
|
||||
# read into
|
||||
st.seek(0)
|
||||
res = _read_into(st, 4)
|
||||
assert_equal(res, b'a\x00st')
|
||||
res = _read_into(st, 4)
|
||||
assert_equal(res, b'ring')
|
||||
assert_raises(IOError, _read_into, st, 2)
|
||||
# read alloc
|
||||
st.seek(0)
|
||||
res = _read_string(st, 4)
|
||||
assert_equal(res, b'a\x00st')
|
||||
res = _read_string(st, 4)
|
||||
assert_equal(res, b'ring')
|
||||
assert_raises(IOError, _read_string, st, 2)
|
||||
|
||||
|
||||
class TestZlibInputStream(object):
|
||||
def _get_data(self, size):
|
||||
data = np.random.randint(0, 256, size).astype(np.uint8).tobytes()
|
||||
compressed_data = zlib.compress(data)
|
||||
stream = BytesIO(compressed_data)
|
||||
return stream, len(compressed_data), data
|
||||
|
||||
def test_read(self):
|
||||
SIZES = [0, 1, 10, BLOCK_SIZE//2, BLOCK_SIZE-1,
|
||||
BLOCK_SIZE, BLOCK_SIZE+1, 2*BLOCK_SIZE-1]
|
||||
|
||||
READ_SIZES = [BLOCK_SIZE//2, BLOCK_SIZE-1,
|
||||
BLOCK_SIZE, BLOCK_SIZE+1]
|
||||
|
||||
def check(size, read_size):
|
||||
compressed_stream, compressed_data_len, data = self._get_data(size)
|
||||
stream = ZlibInputStream(compressed_stream, compressed_data_len)
|
||||
data2 = b''
|
||||
so_far = 0
|
||||
while True:
|
||||
block = stream.read(min(read_size,
|
||||
size - so_far))
|
||||
if not block:
|
||||
break
|
||||
so_far += len(block)
|
||||
data2 += block
|
||||
assert_equal(data, data2)
|
||||
|
||||
for size in SIZES:
|
||||
for read_size in READ_SIZES:
|
||||
check(size, read_size)
|
||||
|
||||
def test_read_max_length(self):
|
||||
size = 1234
|
||||
data = np.random.randint(0, 256, size).astype(np.uint8).tobytes()
|
||||
compressed_data = zlib.compress(data)
|
||||
compressed_stream = BytesIO(compressed_data + b"abbacaca")
|
||||
stream = ZlibInputStream(compressed_stream, len(compressed_data))
|
||||
|
||||
stream.read(len(data))
|
||||
assert_equal(compressed_stream.tell(), len(compressed_data))
|
||||
|
||||
assert_raises(IOError, stream.read, 1)
|
||||
|
||||
def test_read_bad_checksum(self):
|
||||
data = np.random.randint(0, 256, 10).astype(np.uint8).tobytes()
|
||||
compressed_data = zlib.compress(data)
|
||||
|
||||
# break checksum
|
||||
compressed_data = compressed_data[:-1] + bytes([(compressed_data[-1] + 1) & 255])
|
||||
|
||||
compressed_stream = BytesIO(compressed_data)
|
||||
stream = ZlibInputStream(compressed_stream, len(compressed_data))
|
||||
|
||||
assert_raises(zlib.error, stream.read, len(data))
|
||||
|
||||
def test_seek(self):
|
||||
compressed_stream, compressed_data_len, data = self._get_data(1024)
|
||||
|
||||
stream = ZlibInputStream(compressed_stream, compressed_data_len)
|
||||
|
||||
stream.seek(123)
|
||||
p = 123
|
||||
assert_equal(stream.tell(), p)
|
||||
d1 = stream.read(11)
|
||||
assert_equal(d1, data[p:p+11])
|
||||
|
||||
stream.seek(321, 1)
|
||||
p = 123+11+321
|
||||
assert_equal(stream.tell(), p)
|
||||
d2 = stream.read(21)
|
||||
assert_equal(d2, data[p:p+21])
|
||||
|
||||
stream.seek(641, 0)
|
||||
p = 641
|
||||
assert_equal(stream.tell(), p)
|
||||
d3 = stream.read(11)
|
||||
assert_equal(d3, data[p:p+11])
|
||||
|
||||
assert_raises(IOError, stream.seek, 10, 2)
|
||||
assert_raises(IOError, stream.seek, -1, 1)
|
||||
assert_raises(ValueError, stream.seek, 1, 123)
|
||||
|
||||
stream.seek(10000, 1)
|
||||
assert_raises(IOError, stream.read, 12)
|
||||
|
||||
def test_seek_bad_checksum(self):
|
||||
data = np.random.randint(0, 256, 10).astype(np.uint8).tobytes()
|
||||
compressed_data = zlib.compress(data)
|
||||
|
||||
# break checksum
|
||||
compressed_data = compressed_data[:-1] + bytes([(compressed_data[-1] + 1) & 255])
|
||||
|
||||
compressed_stream = BytesIO(compressed_data)
|
||||
stream = ZlibInputStream(compressed_stream, len(compressed_data))
|
||||
|
||||
assert_raises(zlib.error, stream.seek, len(data))
|
||||
|
||||
def test_all_data_read(self):
|
||||
compressed_stream, compressed_data_len, data = self._get_data(1024)
|
||||
stream = ZlibInputStream(compressed_stream, compressed_data_len)
|
||||
assert_(not stream.all_data_read())
|
||||
stream.seek(512)
|
||||
assert_(not stream.all_data_read())
|
||||
stream.seek(1024)
|
||||
assert_(stream.all_data_read())
|
||||
|
||||
def test_all_data_read_overlap(self):
|
||||
COMPRESSION_LEVEL = 6
|
||||
|
||||
data = np.arange(33707000).astype(np.uint8).tobytes()
|
||||
compressed_data = zlib.compress(data, COMPRESSION_LEVEL)
|
||||
compressed_data_len = len(compressed_data)
|
||||
|
||||
# check that part of the checksum overlaps
|
||||
assert_(compressed_data_len == BLOCK_SIZE + 2)
|
||||
|
||||
compressed_stream = BytesIO(compressed_data)
|
||||
stream = ZlibInputStream(compressed_stream, compressed_data_len)
|
||||
assert_(not stream.all_data_read())
|
||||
stream.seek(len(data))
|
||||
assert_(stream.all_data_read())
|
||||
|
||||
def test_all_data_read_bad_checksum(self):
|
||||
COMPRESSION_LEVEL = 6
|
||||
|
||||
data = np.arange(33707000).astype(np.uint8).tobytes()
|
||||
compressed_data = zlib.compress(data, COMPRESSION_LEVEL)
|
||||
compressed_data_len = len(compressed_data)
|
||||
|
||||
# check that part of the checksum overlaps
|
||||
assert_(compressed_data_len == BLOCK_SIZE + 2)
|
||||
|
||||
# break checksum
|
||||
compressed_data = compressed_data[:-1] + bytes([(compressed_data[-1] + 1) & 255])
|
||||
|
||||
compressed_stream = BytesIO(compressed_data)
|
||||
stream = ZlibInputStream(compressed_stream, compressed_data_len)
|
||||
assert_(not stream.all_data_read())
|
||||
stream.seek(len(data))
|
||||
|
||||
assert_raises(zlib.error, stream.all_data_read)
|
Loading…
Add table
Add a link
Reference in a new issue