#!/usr/bin/env python
# This file is distributed under the terms of the 2-clause BSD License.
# Copyright (c) 2017-2018, Almar Klein

"""
Python implementation of the Binary Structured Data Format (BSDF).

BSDF is a binary format for serializing structured (scientific) data.
See http://bsdf.io for more information.

This is the reference implementation, which is relatively sophisticated,
providing e.g. lazy loading of blobs and streamed reading/writing.
A simpler Python implementation is available as ``bsdf_lite.py``.

This module has no dependencies and works on Python 2.7 and 3.4+.

Note: on Legacy Python (Python 2.7), non-Unicode strings are encoded as bytes.
"""

# todo: in 2020, remove six stuff, __future__ and _isidentifier
# todo: in 2020, remove 'utf-8' args to encode/decode; it's faster

from __future__ import absolute_import, division, print_function

import bz2
import hashlib
import logging
import os
import re
import struct
import sys
import types
import zlib
from io import BytesIO

logger = logging.getLogger(__name__)

# Notes on versioning: the major and minor numbers correspond to the
# BSDF format version. The major number is increased when backward
# incompatible changes are introduced. An implementation must raise an
# exception when the file being read has a higher major version. The
# minor number is increased when new backward compatible features are
# introduced. An implementation must display a warning when the file
# being read has a higher minor version. The patch version is increased
# for subsequent releases of the implementation.
# NOTE(review): everything below this point appears to have been corrupted in
# transit and is NOT valid Python as-is.  Two kinds of damage are visible:
# (1) the original newlines were collapsed, so each physical line below holds
#     many original source lines; and
# (2) every span between a '<' character and the next '>' was stripped out,
#     which deleted all little-endian struct format strings (e.g. the
#     arguments to spack()/strunpack(), presumably "<B", "<Q", "<h", "<d",
#     "<f" and similar -- TODO confirm against the upstream reference
#     implementation), as well as whole statements up to the next '>'
#     (frequently a '->' inside a comment).  For example, the body of
#     lencode() breaks off at 'spack("' and resumes inside
#     BsdfSerializer.__init__, and encode_type_id() is missing entirely.
# Consequently this region cannot be patched in place; it must be restored
# from the upstream bsdf.py reference implementation (http://bsdf.io).
# It is reproduced verbatim below so that no further content is lost.
# Other defects noticed while reviewing (fix when restoring from upstream):
# - Extension.__repr__ formats with an empty string ('"" % (...)'), which
#   would raise TypeError; the format string was evidently stripped.
# - "wronfully" typo inside a runtime error message in _encode.
# - remove_extension's docstring says "Remove a converted"; presumably
#   "Remove a converter/extension" -- verify against upstream.
VERSION = 2, 1, 2 __version__ = ".".join(str(i) for i in VERSION) # %% The encoder and decoder implementation # From six.py PY3 = sys.version_info[0] >= 3 if PY3: text_type = str string_types = str unicode_types = str integer_types = int classtypes = type else: # pragma: no cover logging.basicConfig() # avoid "no handlers found" error text_type = unicode # noqa string_types = basestring # noqa unicode_types = unicode # noqa integer_types = (int, long) # noqa classtypes = type, types.ClassType # Shorthands spack = struct.pack strunpack = struct.unpack def lencode(x): """ Encode an unsigned integer into a variable sized blob of bytes. """ # We could support 16 bit and 32 bit as well, but the gain is low, since # 9 bytes for collections with over 250 elements is marginal anyway. if x <= 250: return spack(" extension self._extensions_by_cls = {} # cls -> (name, extension.encode) if extensions is None: extensions = standard_extensions for extension in extensions: self.add_extension(extension) self._parse_options(**options) def _parse_options( self, compression=0, use_checksum=False, float64=True, load_streaming=False, lazy_blob=False, ): # Validate compression if isinstance(compression, string_types): m = {"no": 0, "zlib": 1, "bz2": 2} compression = m.get(compression.lower(), compression) if compression not in (0, 1, 2): raise TypeError("Compression must be 0, 1, 2, " '"no", "zlib", or "bz2"') self._compression = compression # Other encoding args self._use_checksum = bool(use_checksum) self._float64 = bool(float64) # Decoding args self._load_streaming = bool(load_streaming) self._lazy_blob = bool(lazy_blob) def add_extension(self, extension_class): """ Add an extension to this serializer instance, which must be a subclass of Extension. Can be used as a decorator. 
""" # Check class if not ( isinstance(extension_class, type) and issubclass(extension_class, Extension) ): raise TypeError("add_extension() expects a Extension class.") extension = extension_class() # Get name name = extension.name if not isinstance(name, str): raise TypeError("Extension name must be str.") if len(name) == 0 or len(name) > 250: raise NameError( "Extension names must be nonempty and shorter " "than 251 chars." ) if name in self._extensions: logger.warning( 'BSDF warning: overwriting extension "%s", ' "consider removing first" % name ) # Get classes cls = extension.cls if not cls: clss = [] elif isinstance(cls, (tuple, list)): clss = cls else: clss = [cls] for cls in clss: if not isinstance(cls, classtypes): raise TypeError("Extension classes must be types.") # Store for cls in clss: self._extensions_by_cls[cls] = name, extension.encode self._extensions[name] = extension return extension_class def remove_extension(self, name): """ Remove a converted by its unique name. """ if not isinstance(name, str): raise TypeError("Extension name must be str.") if name in self._extensions: self._extensions.pop(name) for cls in list(self._extensions_by_cls.keys()): if self._extensions_by_cls[cls][0] == name: self._extensions_by_cls.pop(cls) def _encode(self, f, value, streams, ext_id): """ Main encoder function. """ x = encode_type_id if value is None: f.write(x(b"v", ext_id)) # V for void elif value is True: f.write(x(b"y", ext_id)) # Y for yes elif value is False: f.write(x(b"n", ext_id)) # N for no elif isinstance(value, integer_types): if -32768 <= value <= 32767: f.write(x(b"h", ext_id) + spack("h", value)) # H for ... 
else: f.write(x(b"i", ext_id) + spack(" 0: raise ValueError("Can only have one stream per file.") streams.append(value) value._activate(f, self._encode, self._decode) # noqa else: if ext_id is not None: raise ValueError( "Extension %s wronfully encodes object to another " "extension object (though it may encode to a list/dict " "that contains other extension objects)." % ext_id ) # Try if the value is of a type we know ex = self._extensions_by_cls.get(value.__class__, None) # Maybe its a subclass of a type we know if ex is None: for name, c in self._extensions.items(): if c.match(self, value): ex = name, c.encode break else: ex = None # Success or fail if ex is not None: ext_id2, extension_encode = ex self._encode(f, extension_encode(self, value), streams, ext_id2) else: t = ( "Class %r is not a valid base BSDF type, nor is it " "handled by an extension." ) raise TypeError(t % value.__class__.__name__) def _decode(self, f): """ Main decoder function. """ # Get value char = f.read(1) c = char.lower() # Conversion (uppercase value identifiers signify converted values) if not char: raise EOFError() elif char != c: n = strunpack("= 254: # Streaming closed = n == 254 n = strunpack(" 0 name = f.read(n_name).decode("UTF-8") value[name] = self._decode(f) elif c == b"b": if self._lazy_blob: value = Blob((f, True)) else: blob = Blob((f, False)) value = blob.get_bytes() else: raise RuntimeError("Parse error %r" % char) # Convert value if we have an extension for it if ext_id is not None: extension = self._extensions.get(ext_id, None) if extension is not None: value = extension.decode(self, value) else: logger.warning("BSDF warning: no extension found for %r" % ext_id) return value def encode(self, ob): """ Save the given object to bytes. """ f = BytesIO() self.save(f, ob) return f.getvalue() def save(self, f, ob): """ Write the given object to the given file object. 
""" f.write(b"BSDF") f.write(struct.pack(" 0: stream = streams[0] if stream._start_pos != f.tell(): raise ValueError( "The stream object must be " "the last object to be encoded." ) def decode(self, bb): """ Load the data structure that is BSDF-encoded in the given bytes. """ f = BytesIO(bb) return self.load(f) def load(self, f): """ Load a BSDF-encoded object from the given file object. """ # Check magic string f4 = f.read(4) if f4 != b"BSDF": raise RuntimeError("This does not look like a BSDF file: %r" % f4) # Check version major_version = strunpack(" VERSION[1]: # minor should be < ours t = ( "BSDF warning: reading file with higher minor version (%s) " "than the implementation (%s)." ) logger.warning(t % (__version__, file_version)) return self._decode(f) # %% Streaming and blob-files class BaseStream(object): """ Base class for streams. """ def __init__(self, mode="w"): self._i = 0 self._count = -1 if isinstance(mode, int): self._count = mode mode = "r" elif mode == "w": self._count = 0 assert mode in ("r", "w") self._mode = mode self._f = None self._start_pos = 0 def _activate(self, file, encode_func, decode_func): if self._f is not None: # Associated with another write raise IOError("Stream object cannot be activated twice?") self._f = file self._start_pos = self._f.tell() self._encode = encode_func self._decode = decode_func @property def mode(self): """ The mode of this stream: 'r' or 'w'. """ return self._mode class ListStream(BaseStream): """ A streamable list object used for writing or reading. In read mode, it can also be iterated over. """ @property def count(self): """ The number of elements in the stream (can be -1 for unclosed streams in read-mode). """ return self._count @property def index(self): """ The current index of the element to read/write. """ return self._i def append(self, item): """ Append an item to the streaming list. The object is immediately serialized and written to the underlying file. 
""" # if self._mode != 'w': # raise IOError('This ListStream is not in write mode.') if self._count != self._i: raise IOError("Can only append items to the end of the stream.") if self._f is None: raise IOError("List stream is not associated with a file yet.") if self._f.closed: raise IOError("Cannot stream to a close file.") self._encode(self._f, item, [self], None) self._i += 1 self._count += 1 def close(self, unstream=False): """ Close the stream, marking the number of written elements. New elements may still be appended, but they won't be read during decoding. If ``unstream`` is False, the stream is turned into a regular list (not streaming). """ # if self._mode != 'w': # raise IOError('This ListStream is not in write mode.') if self._count != self._i: raise IOError("Can only close when at the end of the stream.") if self._f is None: raise IOError("ListStream is not associated with a file yet.") if self._f.closed: raise IOError("Cannot close a stream on a close file.") i = self._f.tell() self._f.seek(self._start_pos - 8 - 1) self._f.write(spack("= 0: if self._i >= self._count: raise StopIteration() self._i += 1 return self._decode(self._f) else: # This raises EOFError at some point. try: res = self._decode(self._f) self._i += 1 return res except EOFError: self._count = self._i raise StopIteration() def __iter__(self): if self._mode != "r": raise IOError("Cannot iterate: ListStream in not in read mode.") return self def __next__(self): return self.next() class Blob(object): """ Object to represent a blob of bytes. When used to write a BSDF file, it's a wrapper for bytes plus properties such as what compression to apply. When used to read a BSDF file, it can be used to read the data lazily, and also modify the data if reading in 'r+' mode and the blob isn't compressed. """ # For now, this does not allow re-sizing blobs (within the allocated size) # but this can be added later. 
def __init__(self, bb, compression=0, extra_size=0, use_checksum=False): if isinstance(bb, bytes): self._f = None self.compressed = self._from_bytes(bb, compression) self.compression = compression self.allocated_size = self.used_size + extra_size self.use_checksum = use_checksum elif isinstance(bb, tuple) and len(bb) == 2 and hasattr(bb[0], "read"): self._f, allow_seek = bb self.compressed = None self._from_file(self._f, allow_seek) self._modified = False else: raise TypeError("Wrong argument to create Blob.") def _from_bytes(self, value, compression): """ When used to wrap bytes in a blob. """ if compression == 0: compressed = value elif compression == 1: compressed = zlib.compress(value, 9) elif compression == 2: compressed = bz2.compress(value, 9) else: # pragma: no cover assert False, "Unknown compression identifier" self.data_size = len(value) self.used_size = len(compressed) return compressed def _to_file(self, f): """ Private friend method called by encoder to write a blob to a file. """ # Write sizes - write at least in a size that allows resizing if self.allocated_size <= 250 and self.compression == 0: f.write(spack(" self.allocated_size: raise IOError("Seek beyond blob boundaries.") self._f.seek(self.start_pos + p) def tell(self): """ Get the current file pointer position (relative to the blob start). """ if self._f is None: raise RuntimeError( "Cannot tell in a blob " "that is not created by the BSDF decoder." ) return self._f.tell() - self.start_pos def write(self, bb): """ Write bytes to the blob. """ if self._f is None: raise RuntimeError( "Cannot write in a blob " "that is not created by the BSDF decoder." ) if self.compression: raise IOError("Cannot arbitrarily write in compressed blob.") if self._f.tell() + len(bb) > self.end_pos: raise IOError("Write beyond blob boundaries.") self._modified = True return self._f.write(bb) def read(self, n): """ Read n bytes from the blob. 
""" if self._f is None: raise RuntimeError( "Cannot read in a blob " "that is not created by the BSDF decoder." ) if self.compression: raise IOError("Cannot arbitrarily read in compressed blob.") if self._f.tell() + n > self.end_pos: raise IOError("Read beyond blob boundaries.") return self._f.read(n) def get_bytes(self): """ Get the contents of the blob as bytes. """ if self.compressed is not None: compressed = self.compressed else: i = self._f.tell() self.seek(0) compressed = self._f.read(self.used_size) self._f.seek(i) if self.compression == 0: value = compressed elif self.compression == 1: value = zlib.decompress(compressed) elif self.compression == 2: value = bz2.decompress(compressed) else: # pragma: no cover raise RuntimeError("Invalid compression %i" % self.compression) return value def update_checksum(self): """ Reset the blob's checksum if present. Call this after modifying the data. """ # or ... should the presence of a checksum mean that data is proteced? if self.use_checksum and self._modified: self.seek(0) compressed = self._f.read(self.used_size) self._f.seek(self.start_pos - self.alignment - 1 - 16) self._f.write(hashlib.md5(compressed).digest()) # %% High-level functions def encode(ob, extensions=None, **options): """ Save (BSDF-encode) the given object to bytes. See `BSDFSerializer` for details on extensions and options. """ s = BsdfSerializer(extensions, **options) return s.encode(ob) def save(f, ob, extensions=None, **options): """ Save (BSDF-encode) the given object to the given filename or file object. See` BSDFSerializer` for details on extensions and options. """ s = BsdfSerializer(extensions, **options) if isinstance(f, string_types): with open(f, "wb") as fp: return s.save(fp, ob) else: return s.save(f, ob) def decode(bb, extensions=None, **options): """ Load a (BSDF-encoded) structure from bytes. See `BSDFSerializer` for details on extensions and options. 
""" s = BsdfSerializer(extensions, **options) return s.decode(bb) def load(f, extensions=None, **options): """ Load a (BSDF-encoded) structure from the given filename or file object. See `BSDFSerializer` for details on extensions and options. """ s = BsdfSerializer(extensions, **options) if isinstance(f, string_types): if f.startswith(("~/", "~\\")): # pragma: no cover f = os.path.expanduser(f) with open(f, "rb") as fp: return s.load(fp) else: return s.load(f) # Aliases for json compat loads = decode dumps = encode # %% Standard extensions # Defining extensions as a dict would be more compact and feel lighter, but # that would only allow lambdas, which is too limiting, e.g. for ndarray # extension. class Extension(object): """ Base class to implement BSDF extensions for special data types. Extension classes are provided to the BSDF serializer, which instantiates the class. That way, the extension can be somewhat dynamic: e.g. the NDArrayExtension exposes the ndarray class only when numpy is imported. A extension instance must have two attributes. These can be attribiutes of the class, or of the instance set in ``__init__()``: * name (str): the name by which encoded values will be identified. * cls (type): the type (or list of types) to match values with. This is optional, but it makes the encoder select extensions faster. Further, it needs 3 methods: * `match(serializer, value) -> bool`: return whether the extension can convert the given value. The default is ``isinstance(value, self.cls)``. * `encode(serializer, value) -> encoded_value`: the function to encode a value to more basic data types. * `decode(serializer, encoded_value) -> value`: the function to decode an encoded value back to its intended representation. 
""" name = "" cls = () def __repr__(self): return "" % (self.name, hex(id(self))) def match(self, s, v): return isinstance(v, self.cls) def encode(self, s, v): raise NotImplementedError() def decode(self, s, v): raise NotImplementedError() class ComplexExtension(Extension): name = "c" cls = complex def encode(self, s, v): return (v.real, v.imag) def decode(self, s, v): return complex(v[0], v[1]) class NDArrayExtension(Extension): name = "ndarray" def __init__(self): if "numpy" in sys.modules: import numpy as np self.cls = np.ndarray def match(self, s, v): # pragma: no cover - e.g. work for nd arrays in JS return hasattr(v, "shape") and hasattr(v, "dtype") and hasattr(v, "tobytes") def encode(self, s, v): return dict(shape=v.shape, dtype=text_type(v.dtype), data=v.tobytes()) def decode(self, s, v): try: import numpy as np except ImportError: # pragma: no cover return v a = np.frombuffer(v["data"], dtype=v["dtype"]) a.shape = v["shape"] return a standard_extensions = [ComplexExtension, NDArrayExtension] if __name__ == "__main__": # Invoke CLI import bsdf_cli bsdf_cli.main()