from pyrsistent._checked_types import CheckedType, _restore_pickle, InvariantException, store_invariants from pyrsistent._field_common import ( set_fields, check_type, is_field_ignore_extra_complaint, PFIELD_NO_INITIAL, serialize, check_global_invariants ) from pyrsistent._pmap import PMap, pmap class _PRecordMeta(type): def __new__(mcs, name, bases, dct): set_fields(dct, bases, name='_precord_fields') store_invariants(dct, bases, '_precord_invariants', '__invariant__') dct['_precord_mandatory_fields'] = \ set(name for name, field in dct['_precord_fields'].items() if field.mandatory) dct['_precord_initial_values'] = \ dict((k, field.initial) for k, field in dct['_precord_fields'].items() if field.initial is not PFIELD_NO_INITIAL) dct['__slots__'] = () return super(_PRecordMeta, mcs).__new__(mcs, name, bases, dct) class PRecord(PMap, CheckedType, metaclass=_PRecordMeta): """ A PRecord is a PMap with a fixed set of specified fields. Records are declared as python classes inheriting from PRecord. Because it is a PMap it has full support for all Mapping methods such as iteration and element access using subscript notation. More documentation and examples of PRecord usage is available at https://github.com/tobgu/pyrsistent """ def __new__(cls, **kwargs): # Hack total! If these two special attributes exist that means we can create # ourselves. Otherwise we need to go through the Evolver to create the structures # for us. if '_precord_size' in kwargs and '_precord_buckets' in kwargs: return super(PRecord, cls).__new__(cls, kwargs['_precord_size'], kwargs['_precord_buckets']) factory_fields = kwargs.pop('_factory_fields', None) ignore_extra = kwargs.pop('_ignore_extra', False) initial_values = kwargs if cls._precord_initial_values: initial_values = dict((k, v() if callable(v) else v) for k, v in cls._precord_initial_values.items()) initial_values.update(kwargs) e = _PRecordEvolver(cls, pmap(), _factory_fields=factory_fields, _ignore_extra=ignore_extra) for k, v in initial_values.items(): e[k] = v return e.persistent() def set(self, *args, **kwargs): """ Set a field in the record. This set function differs slightly from that in the PMap class. First of all it accepts key-value pairs. Second it accepts multiple key-value pairs to perform one, atomic, update of multiple fields. """ # The PRecord set() can accept kwargs since all fields that have been declared are # valid python identifiers. Also allow multiple fields to be set in one operation. if args: return super(PRecord, self).set(args[0], args[1]) return self.update(kwargs) def evolver(self): """ Returns an evolver of this object. """ return _PRecordEvolver(self.__class__, self) def __repr__(self): return "{0}({1})".format(self.__class__.__name__, ', '.join('{0}={1}'.format(k, repr(v)) for k, v in self.items())) @classmethod def create(cls, kwargs, _factory_fields=None, ignore_extra=False): """ Factory method. Will create a new PRecord of the current type and assign the values specified in kwargs. :param ignore_extra: A boolean which when set to True will ignore any keys which appear in kwargs that are not in the set of fields on the PRecord. """ if isinstance(kwargs, cls): return kwargs if ignore_extra: kwargs = {k: kwargs[k] for k in cls._precord_fields if k in kwargs} return cls(_factory_fields=_factory_fields, _ignore_extra=ignore_extra, **kwargs) def __reduce__(self): # Pickling support return _restore_pickle, (self.__class__, dict(self),) def serialize(self, format=None): """ Serialize the current PRecord using custom serializer functions for fields where such have been supplied. """ return dict((k, serialize(self._precord_fields[k].serializer, format, v)) for k, v in self.items()) class _PRecordEvolver(PMap._Evolver): __slots__ = ('_destination_cls', '_invariant_error_codes', '_missing_fields', '_factory_fields', '_ignore_extra') def __init__(self, cls, original_pmap, _factory_fields=None, _ignore_extra=False): super(_PRecordEvolver, self).__init__(original_pmap) self._destination_cls = cls self._invariant_error_codes = [] self._missing_fields = [] self._factory_fields = _factory_fields self._ignore_extra = _ignore_extra def __setitem__(self, key, original_value): self.set(key, original_value) def set(self, key, original_value): field = self._destination_cls._precord_fields.get(key) if field: if self._factory_fields is None or field in self._factory_fields: try: if is_field_ignore_extra_complaint(PRecord, field, self._ignore_extra): value = field.factory(original_value, ignore_extra=self._ignore_extra) else: value = field.factory(original_value) except InvariantException as e: self._invariant_error_codes += e.invariant_errors self._missing_fields += e.missing_fields return self else: value = original_value check_type(self._destination_cls, field, key, value) is_ok, error_code = field.invariant(value) if not is_ok: self._invariant_error_codes.append(error_code) return super(_PRecordEvolver, self).set(key, value) else: raise AttributeError("'{0}' is not among the specified fields for {1}".format(key, self._destination_cls.__name__)) def persistent(self): cls = self._destination_cls is_dirty = self.is_dirty() pm = super(_PRecordEvolver, self).persistent() if is_dirty or not isinstance(pm, cls): result = cls(_precord_buckets=pm._buckets, _precord_size=pm._size) else: result = pm if cls._precord_mandatory_fields: self._missing_fields += tuple('{0}.{1}'.format(cls.__name__, f) for f in (cls._precord_mandatory_fields - set(result.keys()))) if self._invariant_error_codes or self._missing_fields: raise InvariantException(tuple(self._invariant_error_codes), tuple(self._missing_fields), 'Field invariant failed') check_global_invariants(result, cls._precord_invariants) return result