Source code for abilian.core.sqlalchemy

# coding=utf-8
""" Additional data types for sqlalchemy
"""
from __future__ import absolute_import, print_function, division

import sys
import logging
from distutils.version import StrictVersion
import pkg_resources
from functools import partial
import json
import uuid

import pytz
import babel
from flask_sqlalchemy import SQLAlchemy as SAExtension
import sqlalchemy as sa
from sqlalchemy.ext.mutable import Mutable

from .logging import patch_logger


# Module-level logger.
logger = logging.getLogger(__name__)

# Version string of the installed Flask-SQLAlchemy distribution. It is used
# to pick the session/extension implementation and is deleted afterwards
# (see ``del FLASK_SA_VERSION``).
FLASK_SA_VERSION = pkg_resources.get_distribution('Flask-SQLAlchemy').version


@sa.event.listens_for(sa.pool.Pool, "checkout")
def ping_connection(dbapi_connection, connection_record, connection_proxy):
    """
    Ensure connections are valid.

    From: `http://docs.sqlalchemy.org/en/rel_0_8/core/pooling.html`

    Registered as a "checkout" listener on every :class:`sqlalchemy.pool.Pool`.
    In case db has been restarted pool may return invalid connections, so a
    trivial ``SELECT 1`` is issued on each checkout.

    :raises sqlalchemy.exc.DisconnectionError: if the ping fails.
    """
    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("SELECT 1")
    except Exception:
        # Only real errors are turned into a disconnection; a bare `except:`
        # would also swallow KeyboardInterrupt / SystemExit.
        try:
            # Best-effort: closing a cursor on a dead connection may fail too,
            # and must not mask the DisconnectionError below.
            cursor.close()
        except Exception:
            pass
        # optional - dispose the whole pool
        # instead of invalidating one at a time
        # connection_proxy._pool.dispose()

        # raise DisconnectionError - pool will try
        # connecting again up to three times before raising.
        raise sa.exc.DisconnectionError()
    cursor.close()
class AbilianBaseSAExtension(SAExtension):
    """
    Base subclass of :class:`flask_sqlalchemy.SQLAlchemy`.

    Add our custom driver hacks.
    """

    def apply_driver_hacks(self, app, info, options):
        """Apply Flask-SQLAlchemy's driver hacks, then ours.

        - sqlite: set ``isolation_level=None`` in ``connect_args`` unless the
          caller already chose one;
        - postgres*: default ``client_encoding`` to ``'utf8'``.

        :param app: the Flask application.
        :param info: the parsed database URL (only ``drivername`` is read here).
        :param options: keyword arguments for ``create_engine``; updated in
            place.
        :returns: whatever the base implementation returns. Older
            Flask-SQLAlchemy returns None; 2.4+ returns ``(sa_url, options)``
            and unpacks it, so the value must be passed through.
        """
        result = SAExtension.apply_driver_hacks(self, app, info, options)

        if info.drivername == 'sqlite':
            connect_args = options.setdefault('connect_args', {})
            if 'isolation_level' not in connect_args:
                # required to support savepoints/rollback without error. It
                # disables implicit BEGIN/COMMIT statements made by pysqlite (a
                # COMMIT kills all savepoints made).
                connect_args['isolation_level'] = None
        elif info.drivername.startswith('postgres'):
            options.setdefault('client_encoding', 'utf8')

        return result
# pkg_resources.parse_version understands every PEP 440 version string
# (e.g. "2.3.2.post1", "3.0.0.dev0"); distutils' StrictVersion raises
# ValueError on those.
if (pkg_resources.parse_version(FLASK_SA_VERSION)
        <= pkg_resources.parse_version('1.0')):
    # SA extension's scoped session supports 'bind' parameter only after 1.0.
    # This is a fix for it. This is required to ensure transaction rollback
    # during tests, but it's useful in some use cases too.
    from flask_sqlalchemy import _SignallingSession as BaseSession

    class SignallingSession(BaseSession):
        """Flask-SQLAlchemy <= 1.0 session that honours a ``bind`` option."""

        def __init__(self, db, autocommit=False, autoflush=True, **options):
            self.app = db.get_app()
            self._model_changes = {}
            # An explicit bind wins; otherwise use the extension's engine.
            bind = options.pop('bind', None) or db.engine
            # actually we are overriding BaseSession.__init__, so we don't
            # want to call it! Directly call BaseSession parent __init__
            sa.orm.Session.__init__(self, autocommit=autocommit,
                                    autoflush=autoflush, bind=bind,
                                    binds=db.get_binds(self.app), **options)

    class SQLAlchemy(AbilianBaseSAExtension):
        """Extension variant building sessions from :class:`SignallingSession`."""

        def create_scoped_session(self, options=None):
            """Helper factory method that creates a scoped session.

            :param options: session keyword arguments; an optional
                ``scopefunc`` key is forwarded to ``scoped_session``.
                The caller's dict is not modified.
            """
            # override needed to use our SignallingSession implementation.
            # Work on a copy so popping 'scopefunc' leaves the caller's dict
            # intact.
            options = dict(options or {})
            scopefunc = options.pop('scopefunc', None)
            return sa.orm.scoped_session(
                partial(SignallingSession, self, **options),
                scopefunc=scopefunc)
else:
    # Flask-SQLAlchemy > 1.0: bind parameter is supported
    SQLAlchemy = AbilianBaseSAExtension

del FLASK_SA_VERSION


# PATCH flask_sqlalchemy for proper info in debug toolbar.
#
# Original code works only when current app code is involved. If using 3rd
# party app the query is logged but source is marked "unknown". Our patch is a
# "best guess".
def _calling_context(app_path):
    """Describe where a query originated, as ``"file:line (function)"``.

    Walks the stack outwards, starting at the caller of this function.
    Frames whose module is ``app_path`` or one of its submodules are
    recognised as application frames.

    - The first application frame found is returned.
    - Otherwise the result is the first frame reached after leaving the
      ``sqlalchemy`` frames.
    - Failing both, ``'<unknown>'`` is returned.

    The outermost frame (the one without ``f_back``) is not inspected.
    """
    frm = sys._getframe(1)
    entered_sa_code = exited_sa_code = False
    sa_caller = '<unknown>'
    format_name = ('{frm.f_code.co_filename}:{frm.f_lineno} '
                   '({frm.f_code.co_name})'.format)

    while frm.f_back is not None:
        name = frm.f_globals.get('__name__')
        if name and (name == app_path or name.startswith(app_path + '.')):
            return format_name(frm=frm)

        if not exited_sa_code:
            in_sa_code = name and (name == 'sqlalchemy'
                                   or name.startswith('sqlalchemy.'))
            if not entered_sa_code:
                entered_sa_code = in_sa_code
            elif not in_sa_code:
                # exited from sa stack: retain name
                sa_caller = format_name(frm=frm)
                exited_sa_code = True

        frm = frm.f_back

    return sa_caller


import flask_sqlalchemy as flask_sa
patch_logger.info(flask_sa._calling_context)
flask_sa._calling_context = _calling_context
del flask_sa
# END PATCH
def filter_cols(model, *filtered_columns):
    """
    Return the column names of `model`, except the ones given.

    Only mapper properties that carry a ``columns`` attribute are
    considered. Useful for ``defer()`` for example, to retain only the
    columns of interest. The order of the returned names is unspecified.
    """
    mapper = sa.orm.class_mapper(model)
    column_keys = set()
    for prop in mapper.iterate_properties:
        if hasattr(prop, 'columns'):
            column_keys.add(prop.key)
    kept = column_keys.difference(filtered_columns)
    return list(kept)
class MutationDict(Mutable, dict):
    """Provides a dictionary type with mutability support.

    Every in-place mutation calls :meth:`changed` once it has succeeded.
    """

    @classmethod
    def coerce(cls, key, value):
        """Convert plain dictionaries to MutationDict."""
        if not isinstance(value, MutationDict):
            if isinstance(value, dict):
                return MutationDict(value)
            # this call will raise ValueError
            return Mutable.coerce(key, value)
        else:
            return value

    # pickling support. see:
    # http://docs.sqlalchemy.org/en/rel_0_8/orm/extensions/mutable.html#supporting-pickling
    def __getstate__(self):
        return dict(self)

    def __setstate__(self, state):
        self.update(state)

    # dict methods
    def __setitem__(self, key, value):
        """Detect dictionary set events and emit change events."""
        dict.__setitem__(self, key, value)
        self.changed()

    def __delitem__(self, key):
        """Detect dictionary del events and emit change events."""
        dict.__delitem__(self, key)
        self.changed()

    def clear(self):
        dict.clear(self)
        self.changed()

    def update(self, *args, **kwargs):
        # Same signature as dict.update: a mapping/iterable and/or keywords.
        dict.update(self, *args, **kwargs)
        self.changed()

    def setdefault(self, key, failobj=None):
        # Only an insertion is a change. (The previous code read `self.data`,
        # which dict does not have, and raised AttributeError.)
        inserting = key not in self
        result = dict.setdefault(self, key, failobj)
        if inserting:
            self.changed()
        return result

    def pop(self, key, *args):
        # Pop first: a KeyError must not flag the attribute as modified.
        result = dict.pop(self, key, *args)
        self.changed()
        return result

    def popitem(self):
        # Pop first: popping from an empty dict raises before any change event.
        item = dict.popitem(self)
        self.changed()
        return item
class MutationList(Mutable, list):
    """
    Provides a list type with mutability support.

    Every in-place mutation calls :meth:`changed` after the underlying list
    operation.
    """

    @classmethod
    def coerce(cls, key, value):
        """Convert list to MutationList."""
        if not isinstance(value, MutationList):
            if isinstance(value, list):
                return MutationList(value)
            # this call will raise ValueError
            return Mutable.coerce(key, value)
        else:
            return value

    # pickling support. see:
    # http://docs.sqlalchemy.org/en/rel_0_8/orm/extensions/mutable.html#supporting-pickling
    def __getstate__(self):
        # Exclude '_parents' (SQLAlchemy's tracking of owning objects) from
        # the pickled instance state.
        state = self.__dict__.copy()
        state.pop('_parents', None)
        return state

    # list methods
    def __setitem__(self, idx, value):
        list.__setitem__(self, idx, value)
        self.changed()

    def __delitem__(self, idx):
        list.__delitem__(self, idx)
        self.changed()

    def insert(self, idx, value):
        list.insert(self, idx, value)
        self.changed()

    # Python 2 slice hooks; on Python 3 slices go through __setitem__ /
    # __delitem__ above.
    def __setslice__(self, i, j, other):
        list.__setslice__(self, i, j, other)
        self.changed()

    def __delslice__(self, i, j):
        list.__delslice__(self, i, j)
        self.changed()

    def __iadd__(self, other):
        result = list.__iadd__(self, other)
        self.changed()
        return result

    def __imul__(self, n):
        result = list.__imul__(self, n)
        self.changed()
        return result

    def append(self, item):
        list.append(self, item)
        self.changed()

    def pop(self, i=-1):
        item = list.pop(self, i)
        self.changed()
        return item

    def remove(self, item):
        list.remove(self, item)
        self.changed()

    def reverse(self):
        list.reverse(self)
        self.changed()

    def sort(self, *args, **kwargs):
        list.sort(self, *args, **kwargs)
        self.changed()

    def extend(self, other):
        list.extend(self, other)
        self.changed()

    def clear(self):
        """Remove all items.

        Without this override, Python 3's ``list.clear`` would mutate the
        list without flagging the change. ``del self[:]`` works on Python 2
        and 3 and emits exactly one change event: through ``__delslice__`` on
        Python 2, or ``__delitem__`` on Python 3.
        """
        del self[:]
class JSON(sa.types.TypeDecorator):
    """
    Column type storing any structure that :mod:`json` can serialize.

    Usage: ``JSON()``. Accepts the same parameters as
    :class:`sqlalchemy.types.Text`. ``None`` is stored as-is (SQL NULL),
    not as the JSON literal ``null``.
    """
    impl = sa.types.Text

    def process_bind_param(self, value, dialect):
        return value if value is None else json.dumps(value)

    def process_result_value(self, value, dialect):
        return value if value is None else json.loads(value)
class JSONUniqueListType(JSON):
    """
    JSON column for lists; on write, items are de-duplicated and sorted.
    """

    @property
    def python_type(self):
        return MutationList

    def process_bind_param(self, value, dialect):
        # The bound value is not always the list itself (e.g. a plain string
        # used in a LIKE clause), so only list-like values are normalized.
        if isinstance(value, (tuple, list)):
            unique_items = set(value)
            value = sorted(unique_items)
        return JSON.process_bind_param(self, value, dialect)
def JSONDict(*args, **kwargs):
    """
    Column type storing a dict as JSON, with in-place mutation tracking.

    All arguments are forwarded to :class:`JSON`.
    """
    column_type = JSON(*args, **kwargs)
    return MutationDict.as_mutable(column_type)
def JSONList(*args, **kwargs):
    """
    Column type storing a list as JSON, with in-place mutation tracking.

    If the keyword argument `unique_sorted` is given and truthy, list values
    are made unique and sorted on write (:class:`JSONUniqueListType`). All
    other arguments are forwarded to the column type.
    """
    unique_sorted = kwargs.pop('unique_sorted', False)
    column_type_cls = JSONUniqueListType if unique_sorted else JSON
    return MutationList.as_mutable(column_type_cls(*args, **kwargs))
class UUID(sa.types.TypeDecorator):
    """
    Platform-independent UUID type.

    Uses Postgresql's UUID type, otherwise uses CHAR(32), storing as
    stringified hex values.

    From SQLAlchemy documentation.
    """
    impl = sa.types.CHAR

    def load_dialect_impl(self, dialect):
        if dialect.name == 'postgresql':
            return dialect.type_descriptor(sa.dialects.postgresql.UUID())
        return dialect.type_descriptor(sa.types.CHAR(32))

    def process_bind_param(self, value, dialect):
        """Convert `value` to what the dialect stores.

        For postgresql this is ``str(value)``. For other dialects it is the
        32-char hex string; strings are parsed with :class:`uuid.UUID` first,
        which raises ValueError on malformed input.
        """
        if value is None:
            return value
        if dialect.name == 'postgresql':
            return str(value)
        if not isinstance(value, uuid.UUID):
            value = uuid.UUID(value)
        # hexstring. Format the integer explicitly: on Python 3, '%x' rejects
        # a UUID object (it has __int__ but no __index__).
        return "%.32x" % value.int

    def process_result_value(self, value, dialect):
        # Some drivers already hand back uuid.UUID objects; uuid.UUID() would
        # fail on those, so they are returned unchanged.
        if value is None or isinstance(value, uuid.UUID):
            return value
        return uuid.UUID(value)
class Locale(sa.types.TypeDecorator):
    """
    Store a :class:`babel.Locale` instance.

    Stored as ``language``, ``language_TERRITORY`` or, when there is no
    territory, ``language_Script``.
    """
    impl = sa.types.UnicodeText

    @property
    def python_type(self):
        return babel.Locale

    def process_bind_param(self, value, dialect):
        """Serialize a :class:`babel.Locale` or locale string.

        Blank strings are stored as None. Other non-string, non-Locale values
        raise ValueError.
        """
        if value is None:
            return None

        if not isinstance(value, babel.Locale):
            if not isinstance(value, basestring):
                raise ValueError("Unknown locale value: %s" % repr(value))
            if not value.strip():
                return None
            value = babel.Locale.parse(value)

        code = unicode(value.language)
        if value.territory:
            code += u'_' + unicode(value.territory)
        elif value.script:
            # Previously appended `value.territory` here (always None in this
            # branch), producing e.g. u'sr_None' instead of u'sr_Latn'.
            code += u'_' + unicode(value.script)
        return code

    def process_result_value(self, value, dialect):
        return None if value is None else babel.Locale.parse(value)
class Timezone(sa.types.TypeDecorator):
    """
    Store a :class:`pytz.tzfile.DstTzInfo` instance.

    Any pytz timezone is accepted, including ``pytz.utc`` and fixed-offset
    zones. It is stored by its ``zone`` name.
    """
    impl = sa.types.UnicodeText

    @property
    def python_type(self):
        return pytz.tzfile.DstTzInfo

    def process_bind_param(self, value, dialect):
        """Serialize a pytz timezone or timezone name to its zone name.

        Blank strings are stored as None. Other non-string, non-timezone
        values raise ValueError.
        """
        if value is None:
            return None

        # BaseTzInfo is the common base of DstTzInfo, StaticTzInfo and UTC.
        # Checking only DstTzInfo used to reject pytz.utc with ValueError.
        if not isinstance(value, pytz.tzinfo.BaseTzInfo):
            if not isinstance(value, basestring):
                raise ValueError("Unknown timezone value: %s" % repr(value))
            if not value.strip():
                return None
            # `import babel` alone does not load the babel.dates submodule.
            from babel.dates import get_timezone
            value = get_timezone(value)

        return value.zone

    def process_result_value(self, value, dialect):
        if value is None:
            return None
        # `import babel` alone does not load the babel.dates submodule.
        from babel.dates import get_timezone
        return get_timezone(value)