Fixed database typo and removed unnecessary class identifier.
This commit is contained in: parent 00ad49a143, commit 45fb349a7d
5098 changed files with 952558 additions and 85 deletions
940 venv/Lib/site-packages/imageio/plugins/_bsdf.py (normal file)
@@ -0,0 +1,940 @@
#!/usr/bin/env python
# This file is distributed under the terms of the 2-clause BSD License.
# Copyright (c) 2017-2018, Almar Klein

"""
Python implementation of the Binary Structured Data Format (BSDF).

BSDF is a binary format for serializing structured (scientific) data.
See http://bsdf.io for more information.

This is the reference implementation, which is relatively
sophisticated, providing e.g. lazy loading of blobs and streamed
reading/writing. A simpler Python implementation is available as
``bsdf_lite.py``.

This module has no dependencies and works on Python 2.7 and 3.4+.

Note: on Legacy Python (Python 2.7), non-Unicode strings are encoded as bytes.
"""

# todo: in 2020, remove six stuff, __future__ and _isidentifier
# todo: in 2020, remove 'utf-8' args to encode/decode; it's faster

from __future__ import absolute_import, division, print_function

import bz2
import hashlib
import logging
import os
import re
import struct
import sys
import types
import zlib
from io import BytesIO

logger = logging.getLogger(__name__)

# Notes on versioning: the major and minor numbers correspond to the
# BSDF format version. The major number is increased when backward
# incompatible changes are introduced. An implementation must raise an
# exception when the file being read has a higher major version. The
# minor number is increased when new backward compatible features are
# introduced. An implementation must display a warning when the file
# being read has a higher minor version. The patch version is increased
# for subsequent releases of the implementation.
VERSION = 2, 1, 2
__version__ = ".".join(str(i) for i in VERSION)


# %% The encoder and decoder implementation

# From six.py
PY3 = sys.version_info[0] >= 3
if PY3:
    text_type = str
    string_types = str
    unicode_types = str
    integer_types = int
    classtypes = type
else:  # pragma: no cover
    logging.basicConfig()  # avoid "no handlers found" error
    text_type = unicode  # noqa
    string_types = basestring  # noqa
    unicode_types = unicode  # noqa
    integer_types = (int, long)  # noqa
    classtypes = type, types.ClassType

# Shorthands
spack = struct.pack
strunpack = struct.unpack


def lencode(x):
    """ Encode an unsigned integer into a variable sized blob of bytes.
    """
    # We could support 16 bit and 32 bit as well, but the gain is low, since
    # 9 bytes for collections with over 250 elements is marginal anyway.
    if x <= 250:
        return spack("<B", x)
    # elif x < 65536:
    #     return spack('<BH', 251, x)
    # elif x < 4294967296:
    #     return spack('<BI', 252, x)
    else:
        return spack("<BQ", 253, x)


# Include len decoder for completeness; we've inlined it for performance.
def lendecode(f):
    """ Decode an unsigned integer from a file.
    """
    n = strunpack("<B", f.read(1))[0]
    if n == 253:
        n = strunpack("<Q", f.read(8))[0]  # noqa
    return n
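
# For example: lencode(7) == b"\x07" (one byte), while lencode(1000) ==
# b"\xfd" + struct.pack("<Q", 1000) (nine bytes, since 0xfd is 253).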


def encode_type_id(b, ext_id):
    """ Encode the type identifier, with or without extension id.
    """
    if ext_id is not None:
        bb = ext_id.encode("UTF-8")
        return b.upper() + lencode(len(bb)) + bb  # noqa
    else:
        return b  # noqa
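
# E.g. encode_type_id(b"s", "c") == b"S\x01c": the uppercase identifier
# marks an extension-wrapped value and is followed by the length-prefixed
# extension name.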


def _isidentifier(s):  # pragma: no cover
    """ Equivalent of str.isidentifier() for Legacy Python, but slower.
    """
    # http://stackoverflow.com/questions/2544972/
    return (
        isinstance(s, string_types)
        and re.match(r"^\w+$", s, re.UNICODE)
        and re.match(r"^[0-9]", s) is None
    )


class BsdfSerializer(object):
    """ Instances of this class represent a BSDF encoder/decoder.

    It acts as a placeholder for a set of extensions and encoding/decoding
    options. Use this to predefine extensions and options for high
    performance encoding/decoding. For general use, see the functions
    `save()`, `encode()`, `load()`, and `decode()`.

    This implementation of BSDF supports streaming lists (keep adding
    to a list after writing the main file), lazy loading of blobs, and
    in-place editing of blobs (for streams opened with a+).

    Options for encoding:

    * compression (int or str): ``0`` or "no" for no compression (default),
      ``1`` or "zlib" for Zlib compression (same as zip files and PNG), and
      ``2`` or "bz2" for Bz2 compression (more compact but slower writing).
      Note that some BSDF implementations (e.g. JavaScript) may not support
      compression.
    * use_checksum (bool): whether to include a checksum with binary blobs.
    * float64 (bool): Whether to write floats as 64 bit (default) or 32 bit.

    Options for decoding:

    * load_streaming (bool): if True, and the final object in the structure was
      a stream, will make it available as a stream in the decoded object.
    * lazy_blob (bool): if True, bytes are represented as Blob objects that can
      be used to lazily access the data, and also overwrite the data if the
      file is open in a+ mode.
    """

    def __init__(self, extensions=None, **options):
        self._extensions = {}  # name -> extension
        self._extensions_by_cls = {}  # cls -> (name, extension.encode)
        if extensions is None:
            extensions = standard_extensions
        for extension in extensions:
            self.add_extension(extension)
        self._parse_options(**options)

    def _parse_options(
        self,
        compression=0,
        use_checksum=False,
        float64=True,
        load_streaming=False,
        lazy_blob=False,
    ):

        # Validate compression
        if isinstance(compression, string_types):
            m = {"no": 0, "zlib": 1, "bz2": 2}
            compression = m.get(compression.lower(), compression)
        if compression not in (0, 1, 2):
            raise TypeError('Compression must be 0, 1, 2, "no", "zlib", or "bz2"')
        self._compression = compression

        # Other encoding args
        self._use_checksum = bool(use_checksum)
        self._float64 = bool(float64)

        # Decoding args
        self._load_streaming = bool(load_streaming)
        self._lazy_blob = bool(lazy_blob)

    def add_extension(self, extension_class):
        """ Add an extension to this serializer instance, which must be
        a subclass of Extension. Can be used as a decorator.
        """
        # Check class
        if not (
            isinstance(extension_class, type) and issubclass(extension_class, Extension)
        ):
            raise TypeError("add_extension() expects an Extension class.")
        extension = extension_class()

        # Get name
        name = extension.name
        if not isinstance(name, str):
            raise TypeError("Extension name must be str.")
        if len(name) == 0 or len(name) > 250:
            raise NameError(
                "Extension names must be nonempty and shorter than 251 chars."
            )
        if name in self._extensions:
            logger.warning(
                'BSDF warning: overwriting extension "%s", '
                "consider removing first" % name
            )

        # Get classes
        cls = extension.cls
        if not cls:
            clss = []
        elif isinstance(cls, (tuple, list)):
            clss = cls
        else:
            clss = [cls]
        for cls in clss:
            if not isinstance(cls, classtypes):
                raise TypeError("Extension classes must be types.")

        # Store
        for cls in clss:
            self._extensions_by_cls[cls] = name, extension.encode
        self._extensions[name] = extension
        return extension_class

    def remove_extension(self, name):
        """ Remove an extension by its unique name.
        """
        if not isinstance(name, str):
            raise TypeError("Extension name must be str.")
        if name in self._extensions:
            self._extensions.pop(name)
        for cls in list(self._extensions_by_cls.keys()):
            if self._extensions_by_cls[cls][0] == name:
                self._extensions_by_cls.pop(cls)

    def _encode(self, f, value, streams, ext_id):
        """ Main encoder function.
        """
        x = encode_type_id

        if value is None:
            f.write(x(b"v", ext_id))  # V for void
        elif value is True:
            f.write(x(b"y", ext_id))  # Y for yes
        elif value is False:
            f.write(x(b"n", ext_id))  # N for no
        elif isinstance(value, integer_types):
            if -32768 <= value <= 32767:
                f.write(x(b"h", ext_id) + spack("<h", value))  # H for short (16 bit) int
            else:
                f.write(x(b"i", ext_id) + spack("<q", value))  # I for int
        elif isinstance(value, float):
            if self._float64:
                f.write(x(b"d", ext_id) + spack("<d", value))  # D for double
            else:
                f.write(x(b"f", ext_id) + spack("<f", value))  # f for float
        elif isinstance(value, unicode_types):
            bb = value.encode("UTF-8")
            f.write(x(b"s", ext_id) + lencode(len(bb)))  # S for str
            f.write(bb)
        elif isinstance(value, (list, tuple)):
            f.write(x(b"l", ext_id) + lencode(len(value)))  # L for list
            for v in value:
                self._encode(f, v, streams, None)
        elif isinstance(value, dict):
            f.write(x(b"m", ext_id) + lencode(len(value)))  # M for mapping
            for key, v in value.items():
                if PY3:
                    assert key.isidentifier()  # faster
                else:  # pragma: no cover
                    assert _isidentifier(key)
                # yield ' ' * indent + key
                name_b = key.encode("UTF-8")
                f.write(lencode(len(name_b)))
                f.write(name_b)
                self._encode(f, v, streams, None)
        elif isinstance(value, bytes):
            f.write(x(b"b", ext_id))  # B for blob
            blob = Blob(
                value, compression=self._compression, use_checksum=self._use_checksum
            )
            blob._to_file(f)  # noqa
        elif isinstance(value, Blob):
            f.write(x(b"b", ext_id))  # B for blob
            value._to_file(f)  # noqa
        elif isinstance(value, BaseStream):
            # Initialize the stream
            if value.mode != "w":
                raise ValueError("Cannot serialize a read-mode stream.")
            elif isinstance(value, ListStream):
                f.write(x(b"l", ext_id) + spack("<BQ", 255, 0))  # L for list
            else:
                raise TypeError("Only ListStream is supported")
            # Mark this as *the* stream, and activate the stream.
            # The save() function verifies this is the last written object.
            if len(streams) > 0:
                raise ValueError("Can only have one stream per file.")
            streams.append(value)
            value._activate(f, self._encode, self._decode)  # noqa
        else:
            if ext_id is not None:
                raise ValueError(
                    "Extension %s wrongfully encodes object to another "
                    "extension object (though it may encode to a list/dict "
                    "that contains other extension objects)." % ext_id
                )
            # Check if the value is of a type we know
            ex = self._extensions_by_cls.get(value.__class__, None)
            # Maybe it's a subclass of a type we know
            if ex is None:
                for name, c in self._extensions.items():
                    if c.match(self, value):
                        ex = name, c.encode
                        break
                else:
                    ex = None
            # Success or fail
            if ex is not None:
                ext_id2, extension_encode = ex
                self._encode(f, extension_encode(self, value), streams, ext_id2)
            else:
                t = (
                    "Class %r is not a valid base BSDF type, nor is it "
                    "handled by an extension."
                )
                raise TypeError(t % value.__class__.__name__)
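
    # Summary of the type identifiers written above (an uppercase variant
    # means the value is wrapped by an extension): v void, y True, n False,
    # h 16 bit int, i 64 bit int, f 32 bit float, d 64 bit float, s string,
    # l list, m mapping (dict), b blob (bytes).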

    def _decode(self, f):
        """ Main decoder function.
        """

        # Get value
        char = f.read(1)
        c = char.lower()

        # Conversion (uppercase value identifiers signify converted values)
        if not char:
            raise EOFError()
        elif char != c:
            n = strunpack("<B", f.read(1))[0]
            # if n == 253: n = strunpack('<Q', f.read(8))[0]  # noqa - not needed
            ext_id = f.read(n).decode("UTF-8")
        else:
            ext_id = None

        if c == b"v":
            value = None
        elif c == b"y":
            value = True
        elif c == b"n":
            value = False
        elif c == b"h":
            value = strunpack("<h", f.read(2))[0]
        elif c == b"i":
            value = strunpack("<q", f.read(8))[0]
        elif c == b"f":
            value = strunpack("<f", f.read(4))[0]
        elif c == b"d":
            value = strunpack("<d", f.read(8))[0]
        elif c == b"s":
            n_s = strunpack("<B", f.read(1))[0]
            if n_s == 253:
                n_s = strunpack("<Q", f.read(8))[0]  # noqa
            value = f.read(n_s).decode("UTF-8")
        elif c == b"l":
            n = strunpack("<B", f.read(1))[0]
            if n >= 254:
                # Streaming
                closed = n == 254
                n = strunpack("<Q", f.read(8))[0]
                if self._load_streaming:
                    value = ListStream(n if closed else "r")
                    value._activate(f, self._encode, self._decode)  # noqa
                elif closed:
                    value = [self._decode(f) for i in range(n)]
                else:
                    value = []
                    try:
                        while True:
                            value.append(self._decode(f))
                    except EOFError:
                        pass
            else:
                # Normal
                if n == 253:
                    n = strunpack("<Q", f.read(8))[0]  # noqa
                value = [self._decode(f) for i in range(n)]
        elif c == b"m":
            value = dict()
            n = strunpack("<B", f.read(1))[0]
            if n == 253:
                n = strunpack("<Q", f.read(8))[0]  # noqa
            for i in range(n):
                n_name = strunpack("<B", f.read(1))[0]
                if n_name == 253:
                    n_name = strunpack("<Q", f.read(8))[0]  # noqa
                assert n_name > 0
                name = f.read(n_name).decode("UTF-8")
                value[name] = self._decode(f)
        elif c == b"b":
            if self._lazy_blob:
                value = Blob((f, True))
            else:
                blob = Blob((f, False))
                value = blob.get_bytes()
        else:
            raise RuntimeError("Parse error %r" % char)

        # Convert value if we have an extension for it
        if ext_id is not None:
            extension = self._extensions.get(ext_id, None)
            if extension is not None:
                value = extension.decode(self, value)
            else:
                logger.warning("BSDF warning: no extension found for %r" % ext_id)

        return value

    def encode(self, ob):
        """ Save the given object to bytes.
        """
        f = BytesIO()
        self.save(f, ob)
        return f.getvalue()

    def save(self, f, ob):
        """ Write the given object to the given file object.
        """
        f.write(b"BSDF")
        f.write(struct.pack("<B", VERSION[0]))
        f.write(struct.pack("<B", VERSION[1]))

        # Prepare streaming, this list will have 0 or 1 item at the end
        streams = []

        self._encode(f, ob, streams, None)

        # Verify that stream object was at the end, and add initial elements
        if len(streams) > 0:
            stream = streams[0]
            if stream._start_pos != f.tell():
                raise ValueError(
                    "The stream object must be the last object to be encoded."
                )
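
    # Note: the header written by save() is just the magic bytes b"BSDF"
    # plus one byte each for the major and minor format version, e.g.
    # b"BSDF\x02\x01" for format version 2.1.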

    def decode(self, bb):
        """ Load the data structure that is BSDF-encoded in the given bytes.
        """
        f = BytesIO(bb)
        return self.load(f)

    def load(self, f):
        """ Load a BSDF-encoded object from the given file object.
        """
        # Check magic string
        f4 = f.read(4)
        if f4 != b"BSDF":
            raise RuntimeError("This does not look like a BSDF file: %r" % f4)
        # Check version
        major_version = strunpack("<B", f.read(1))[0]
        minor_version = strunpack("<B", f.read(1))[0]
        file_version = "%i.%i" % (major_version, minor_version)
        if major_version != VERSION[0]:  # major version should be 2
            t = (
                "Reading file with different major version (%s) "
                "from the implementation (%s)."
            )
            raise RuntimeError(t % (file_version, __version__))
        if minor_version > VERSION[1]:  # minor should be <= ours
            t = (
                "BSDF warning: reading file with higher minor version (%s) "
                "than the implementation (%s)."
            )
            logger.warning(t % (file_version, __version__))

        return self._decode(f)


# %% Streaming and blob-files


class BaseStream(object):
    """ Base class for streams.
    """

    def __init__(self, mode="w"):
        self._i = 0
        self._count = -1
        if isinstance(mode, int):
            self._count = mode
            mode = "r"
        elif mode == "w":
            self._count = 0
        assert mode in ("r", "w")
        self._mode = mode
        self._f = None
        self._start_pos = 0

    def _activate(self, file, encode_func, decode_func):
        if self._f is not None:  # Associated with another write
            raise IOError("Stream object cannot be activated twice.")
        self._f = file
        self._start_pos = self._f.tell()
        self._encode = encode_func
        self._decode = decode_func

    @property
    def mode(self):
        """ The mode of this stream: 'r' or 'w'.
        """
        return self._mode


class ListStream(BaseStream):
    """ A streamable list object used for writing or reading.
    In read mode, it can also be iterated over.
    """

    @property
    def count(self):
        """ The number of elements in the stream (can be -1 for unclosed
        streams in read-mode).
        """
        return self._count

    @property
    def index(self):
        """ The current index of the element to read/write.
        """
        return self._i

    def append(self, item):
        """ Append an item to the streaming list. The object is immediately
        serialized and written to the underlying file.
        """
        # if self._mode != 'w':
        #     raise IOError('This ListStream is not in write mode.')
        if self._count != self._i:
            raise IOError("Can only append items to the end of the stream.")
        if self._f is None:
            raise IOError("List stream is not associated with a file yet.")
        if self._f.closed:
            raise IOError("Cannot stream to a closed file.")
        self._encode(self._f, item, [self], None)
        self._i += 1
        self._count += 1

    def close(self, unstream=False):
        """ Close the stream, marking the number of written elements. New
        elements may still be appended, but they won't be read during decoding.
        If ``unstream`` is True, the stream is turned into a regular list
        (not streaming).
        """
        # if self._mode != 'w':
        #     raise IOError('This ListStream is not in write mode.')
        if self._count != self._i:
            raise IOError("Can only close when at the end of the stream.")
        if self._f is None:
            raise IOError("ListStream is not associated with a file yet.")
        if self._f.closed:
            raise IOError("Cannot close a stream on a closed file.")
        i = self._f.tell()
        self._f.seek(self._start_pos - 8 - 1)
        self._f.write(spack("<B", 253 if unstream else 254))
        self._f.write(spack("<Q", self._count))
        self._f.seek(i)

    def next(self):
        """ Read and return the next element in the streaming list.
        Raises StopIteration if the stream is exhausted.
        """
        if self._mode != "r":
            raise IOError("This ListStream is not in read mode.")
        if self._f is None:
            raise IOError("ListStream is not associated with a file yet.")
        if getattr(self._f, "closed", None):  # not present on 2.7 http req :/
            raise IOError("Cannot read a stream from a closed file.")
        if self._count >= 0:
            if self._i >= self._count:
                raise StopIteration()
            self._i += 1
            return self._decode(self._f)
        else:
            # This raises EOFError at some point.
            try:
                res = self._decode(self._f)
                self._i += 1
                return res
            except EOFError:
                self._count = self._i
                raise StopIteration()

    def __iter__(self):
        if self._mode != "r":
            raise IOError("Cannot iterate: ListStream is not in read mode.")
        return self

    def __next__(self):
        return self.next()
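

# Illustrative sketch (not part of the library): writing a streamed list
# and reading it back; ``save`` and ``load`` are the module-level helpers
# defined below, and the file name is hypothetical.
def _example_list_stream(path="example.bsdf"):  # pragma: no cover
    stream = ListStream()
    with open(path, "wb") as f:
        save(f, stream)  # the stream must be the last encoded object
        stream.append("item 1")
        stream.append("item 2")
        stream.close()  # patches the element count into the file
    return load(path)  # -> ["item 1", "item 2"]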


class Blob(object):
    """ Object to represent a blob of bytes. When used to write a BSDF file,
    it's a wrapper for bytes plus properties such as what compression to apply.
    When used to read a BSDF file, it can be used to read the data lazily, and
    also modify the data if reading in 'r+' mode and the blob isn't compressed.
    """

    # For now, this does not allow re-sizing blobs (within the allocated size)
    # but this can be added later.

    def __init__(self, bb, compression=0, extra_size=0, use_checksum=False):
        if isinstance(bb, bytes):
            self._f = None
            self.compressed = self._from_bytes(bb, compression)
            self.compression = compression
            self.allocated_size = self.used_size + extra_size
            self.use_checksum = use_checksum
        elif isinstance(bb, tuple) and len(bb) == 2 and hasattr(bb[0], "read"):
            self._f, allow_seek = bb
            self.compressed = None
            self._from_file(self._f, allow_seek)
            self._modified = False
        else:
            raise TypeError("Wrong argument to create Blob.")

    def _from_bytes(self, value, compression):
        """ When used to wrap bytes in a blob.
        """
        if compression == 0:
            compressed = value
        elif compression == 1:
            compressed = zlib.compress(value, 9)
        elif compression == 2:
            compressed = bz2.compress(value, 9)
        else:  # pragma: no cover
            assert False, "Unknown compression identifier"

        self.data_size = len(value)
        self.used_size = len(compressed)
        return compressed

    def _to_file(self, f):
        """ Private friend method called by encoder to write a blob to a file.
        """
        # Write sizes - write at least in a size that allows resizing
        if self.allocated_size <= 250 and self.compression == 0:
            f.write(spack("<B", self.allocated_size))
            f.write(spack("<B", self.used_size))
            f.write(lencode(self.data_size))
        else:
            f.write(spack("<BQ", 253, self.allocated_size))
            f.write(spack("<BQ", 253, self.used_size))
            f.write(spack("<BQ", 253, self.data_size))
        # Compression and checksum
        f.write(spack("B", self.compression))
        if self.use_checksum:
            f.write(b"\xff" + hashlib.md5(self.compressed).digest())
        else:
            f.write(b"\x00")
        # Byte alignment (only necessary for uncompressed data)
        if self.compression == 0:
            alignment = 8 - (f.tell() + 1) % 8  # +1 for the byte to write
            f.write(spack("<B", alignment))  # padding for byte alignment
            f.write(b"\x00" * alignment)
        else:
            f.write(spack("<B", 0))
        # The actual data and extra space
        f.write(self.compressed)
        f.write(b"\x00" * (self.allocated_size - self.used_size))

    def _from_file(self, f, allow_seek):
        """ Used when a blob is read by the decoder.
        """
        # Read blob header data (5 to 42 bytes)
        # Size
        allocated_size = strunpack("<B", f.read(1))[0]
        if allocated_size == 253:
            allocated_size = strunpack("<Q", f.read(8))[0]  # noqa
        used_size = strunpack("<B", f.read(1))[0]
        if used_size == 253:
            used_size = strunpack("<Q", f.read(8))[0]  # noqa
        data_size = strunpack("<B", f.read(1))[0]
        if data_size == 253:
            data_size = strunpack("<Q", f.read(8))[0]  # noqa
        # Compression and checksum
        compression = strunpack("<B", f.read(1))[0]
        has_checksum = strunpack("<B", f.read(1))[0]
        if has_checksum:
            checksum = f.read(16)
        # Skip alignment
        alignment = strunpack("<B", f.read(1))[0]
        f.read(alignment)
        # Get or skip data + extra space
        if allow_seek:
            self.start_pos = f.tell()
            self.end_pos = self.start_pos + used_size
            f.seek(self.start_pos + allocated_size)
        else:
            self.start_pos = None
            self.end_pos = None
            self.compressed = f.read(used_size)
            f.read(allocated_size - used_size)
        # Store info
        self.alignment = alignment
        self.compression = compression
        self.use_checksum = checksum if has_checksum else None
        self.used_size = used_size
        self.allocated_size = allocated_size
        self.data_size = data_size

    def seek(self, p):
        """ Seek to the given position (relative to the blob start).
        """
        if self._f is None:
            raise RuntimeError(
                "Cannot seek in a blob that is not created by the BSDF decoder."
            )
        if p < 0:
            p = self.allocated_size + p
        if p < 0 or p > self.allocated_size:
            raise IOError("Seek beyond blob boundaries.")
        self._f.seek(self.start_pos + p)

    def tell(self):
        """ Get the current file pointer position (relative to the blob start).
        """
        if self._f is None:
            raise RuntimeError(
                "Cannot tell in a blob that is not created by the BSDF decoder."
            )
        return self._f.tell() - self.start_pos

    def write(self, bb):
        """ Write bytes to the blob.
        """
        if self._f is None:
            raise RuntimeError(
                "Cannot write in a blob that is not created by the BSDF decoder."
            )
        if self.compression:
            raise IOError("Cannot arbitrarily write in compressed blob.")
        if self._f.tell() + len(bb) > self.end_pos:
            raise IOError("Write beyond blob boundaries.")
        self._modified = True
        return self._f.write(bb)

    def read(self, n):
        """ Read n bytes from the blob.
        """
        if self._f is None:
            raise RuntimeError(
                "Cannot read in a blob that is not created by the BSDF decoder."
            )
        if self.compression:
            raise IOError("Cannot arbitrarily read in compressed blob.")
        if self._f.tell() + n > self.end_pos:
            raise IOError("Read beyond blob boundaries.")
        return self._f.read(n)

    def get_bytes(self):
        """ Get the contents of the blob as bytes.
        """
        if self.compressed is not None:
            compressed = self.compressed
        else:
            i = self._f.tell()
            self.seek(0)
            compressed = self._f.read(self.used_size)
            self._f.seek(i)
        if self.compression == 0:
            value = compressed
        elif self.compression == 1:
            value = zlib.decompress(compressed)
        elif self.compression == 2:
            value = bz2.decompress(compressed)
        else:  # pragma: no cover
            raise RuntimeError("Invalid compression %i" % self.compression)
        return value

    def update_checksum(self):
        """ Reset the blob's checksum if present. Call this after modifying
        the data.
        """
        # or ... should the presence of a checksum mean that data is protected?
        if self.use_checksum and self._modified:
            self.seek(0)
            compressed = self._f.read(self.used_size)
            self._f.seek(self.start_pos - self.alignment - 1 - 16)
            self._f.write(hashlib.md5(compressed).digest())
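

# Illustrative sketch (not part of the library): lazily accessing a blob.
# ``encode`` and ``decode`` are the module-level helpers defined below.
def _example_lazy_blob():  # pragma: no cover
    bb = encode({"data": b"\x00" * 1000})  # uncompressed by default
    blob = decode(bb, lazy_blob=True)["data"]
    blob.seek(0)
    return blob.read(4)  # reads just 4 bytes instead of the whole blob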


# %% High-level functions


def encode(ob, extensions=None, **options):
    """ Save (BSDF-encode) the given object to bytes.
    See `BsdfSerializer` for details on extensions and options.
    """
    s = BsdfSerializer(extensions, **options)
    return s.encode(ob)


def save(f, ob, extensions=None, **options):
    """ Save (BSDF-encode) the given object to the given filename or
    file object. See `BsdfSerializer` for details on extensions and options.
    """
    s = BsdfSerializer(extensions, **options)
    if isinstance(f, string_types):
        with open(f, "wb") as fp:
            return s.save(fp, ob)
    else:
        return s.save(f, ob)


def decode(bb, extensions=None, **options):
    """ Load a (BSDF-encoded) structure from bytes.
    See `BsdfSerializer` for details on extensions and options.
    """
    s = BsdfSerializer(extensions, **options)
    return s.decode(bb)


def load(f, extensions=None, **options):
    """ Load a (BSDF-encoded) structure from the given filename or file object.
    See `BsdfSerializer` for details on extensions and options.
    """
    s = BsdfSerializer(extensions, **options)
    if isinstance(f, string_types):
        if f.startswith(("~/", "~\\")):  # pragma: no cover
            f = os.path.expanduser(f)
        with open(f, "rb") as fp:
            return s.load(fp)
    else:
        return s.load(f)


# Aliases for json compat
loads = decode
dumps = encode
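
# For example (illustrative): ``loads(dumps([1, 2, 3])) == [1, 2, 3]``.
# Unlike json, the encoded form is bytes rather than str.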


# %% Standard extensions

# Defining extensions as a dict would be more compact and feel lighter, but
# that would only allow lambdas, which is too limiting, e.g. for ndarray
# extension.


class Extension(object):
    """ Base class to implement BSDF extensions for special data types.

    Extension classes are provided to the BSDF serializer, which
    instantiates the class. That way, the extension can be somewhat dynamic:
    e.g. the NDArrayExtension exposes the ndarray class only when numpy
    is imported.

    An extension instance must have two attributes. These can be attributes of
    the class, or of the instance set in ``__init__()``:

    * name (str): the name by which encoded values will be identified.
    * cls (type): the type (or list of types) to match values with.
      This is optional, but it makes the encoder select extensions faster.

    Further, it needs 3 methods:

    * `match(serializer, value) -> bool`: return whether the extension can
      convert the given value. The default is ``isinstance(value, self.cls)``.
    * `encode(serializer, value) -> encoded_value`: the function to encode a
      value to more basic data types.
    * `decode(serializer, encoded_value) -> value`: the function to decode an
      encoded value back to its intended representation.
    """

    name = ""
    cls = ()

    def __repr__(self):
        return "<BSDF extension %r at %s>" % (self.name, hex(id(self)))

    def match(self, s, v):
        return isinstance(v, self.cls)

    def encode(self, s, v):
        raise NotImplementedError()

    def decode(self, s, v):
        raise NotImplementedError()
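

# Illustrative sketch (not part of the library): a hypothetical extension
# that round-trips Python sets by encoding them as lists, e.g.
# decode(encode({1, 2}, [_SetExtension]), [_SetExtension]) == {1, 2}.
class _SetExtension(Extension):  # pragma: no cover
    name = "set"
    cls = set

    def encode(self, s, v):
        return list(v)

    def decode(self, s, v):
        return set(v)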


class ComplexExtension(Extension):

    name = "c"
    cls = complex

    def encode(self, s, v):
        return (v.real, v.imag)

    def decode(self, s, v):
        return complex(v[0], v[1])


class NDArrayExtension(Extension):

    name = "ndarray"

    def __init__(self):
        if "numpy" in sys.modules:
            import numpy as np

            self.cls = np.ndarray

    def match(self, s, v):  # pragma: no cover - e.g. works for nd arrays in JS
        return hasattr(v, "shape") and hasattr(v, "dtype") and hasattr(v, "tobytes")

    def encode(self, s, v):
        return dict(shape=v.shape, dtype=text_type(v.dtype), data=v.tobytes())

    def decode(self, s, v):
        try:
            import numpy as np
        except ImportError:  # pragma: no cover
            return v
        a = np.frombuffer(v["data"], dtype=v["dtype"])
        a.shape = v["shape"]
        return a


standard_extensions = [ComplexExtension, NDArrayExtension]


if __name__ == "__main__":
    # Invoke CLI
    import bsdf_cli

    bsdf_cli.main()