diff --git a/postgresql_audit/alembic/__init__.py b/postgresql_audit/alembic/__init__.py new file mode 100644 index 0000000..e0509da --- /dev/null +++ b/postgresql_audit/alembic/__init__.py @@ -0,0 +1,129 @@ +import re +from itertools import groupby + +from alembic.autogenerate import comparators, rewriter +from alembic.operations import ops + +from postgresql_audit.alembic.init_activity_table_triggers import InitActivityTableTriggersOp, \ + RemoveActivityTableTriggersOp +from postgresql_audit.alembic.migration_ops import AddColumnToActivityOp, RemoveColumnFromRemoveActivityOp +from postgresql_audit.alembic.register_table_for_version_tracking import RegisterTableForVersionTrackingOp, \ + DeregisterTableForVersionTrackingOp + + +@comparators.dispatch_for("schema") +def compare_timestamp_schema(autogen_context, upgrade_ops, schemas): + routines = set() + for sch in schemas: + schema_name = autogen_context.dialect.default_schema_name if sch is None else sch + routines.update([ + (sch, *row) for row in autogen_context.connection.execute( + "select routine_name, routine_definition from information_schema.routines " + f"where routines.specific_schema='{schema_name}' " + )]) + + for sch in schemas: + should_track_versions = any("versioned" in table.info for table in autogen_context.sorted_tables if table.info and table.schema == sch) + schema_prefix = f"{sch}." if sch else "" + + a = next((v for k, v in groupby(routines, key=lambda x: x[0]) if k == sch), None) + a = list(a) if a else [] + if should_track_versions: + if f"{schema_prefix}audit_table" not in (x[1] for x in a): + upgrade_ops.ops.insert(0, + InitActivityTableTriggersOp(False, schema=sch) + ) + else: + if f"{schema_prefix}audit_table" in (x[1] for x in a): + upgrade_ops.ops.append( + RemoveActivityTableTriggersOp(False, schema=sch) + ) + + +@comparators.dispatch_for("table") +def compare_timestamp_table(autogen_context, modify_ops, schemaname, tablename, conn_table, metadata_table): + if metadata_table is None: + return + meta_info = metadata_table.info or {} + schema_name = autogen_context.dialect.default_schema_name if schemaname is None else schemaname + + triggers = [row for row in autogen_context.connection.execute(f""" +select event_object_schema as table_schema, + event_object_table as table_name, + trigger_schema, + trigger_name, + string_agg(event_manipulation, ',') as event, + action_timing as activation, + action_condition as condition, + action_statement as definition +from information_schema.triggers +where event_object_table = '{tablename}' and trigger_schema = '{schema_name}' +group by 1,2,3,4,6,7,8 +order by table_schema, table_name; + """)] + + trigger_name = "audit_trigger" + + if "versioned" in meta_info: + excluded_columns = metadata_table.info["versioned"].get("exclude", tuple()) + trigger = next((trigger for trigger in triggers if trigger_name in trigger[3]), None) + original_excluded_columns = __get_existing_excluded_columns(trigger) + + if trigger and set(original_excluded_columns) == set(excluded_columns): + return + + modify_ops.ops.insert(0, + RegisterTableForVersionTrackingOp(tablename, excluded_columns, original_excluded_columns, schema=schema_name) + ) + else: + trigger = next((trigger for trigger in triggers if trigger_name in trigger[3]), None) + original_excluded_columns = __get_existing_excluded_columns(trigger) + + if trigger: + modify_ops.ops.append( + DeregisterTableForVersionTrackingOp(tablename, original_excluded_columns, schema=schema_name) + ) + + +def __get_existing_excluded_columns(trigger): + original_excluded_columns = () + if trigger: + arguments_match = re.search(r"EXECUTE FUNCTION create_activity\('{(.+)}'\)", trigger[7]) + if arguments_match: + original_excluded_columns = arguments_match.group(1).split(",") + return original_excluded_columns + + +writer = rewriter.Rewriter() + +@writer.rewrites(ops.AddColumnOp) +def add_column_rewrite(context, revision, op): + table_info = op.column.table.info or {} + if "versioned" in table_info and op.column.name not in table_info["versioned"].get("exclude", []): + return [ + op, + AddColumnToActivityOp( + op.table_name, + op.column.name, + schema=op.column.table.schema, + ), + ] + else: + return op + +@writer.rewrites(ops.DropColumnOp) +def drop_column_rewrite(context, revision, op): + column = op._orig_column + table_info = column.table.info or {} + if "versioned" in table_info and column.name not in table_info["versioned"].get("exclude", []): + return [ + op, + RemoveColumnFromRemoveActivityOp( + op.table_name, + column.name, + schema=column.table.schema, + ), + ] + else: + return op + diff --git a/postgresql_audit/alembic/init_activity_table_triggers.py b/postgresql_audit/alembic/init_activity_table_triggers.py new file mode 100644 index 0000000..a40364a --- /dev/null +++ b/postgresql_audit/alembic/init_activity_table_triggers.py @@ -0,0 +1,98 @@ +from alembic.autogenerate import renderers +from alembic.operations import Operations, MigrateOperation + +from postgresql_audit.utils import render_tmpl, create_audit_table, create_operators + + +@Operations.register_operation("init_activity_table_triggers") +class InitActivityTableTriggersOp(MigrateOperation): + """Initialize Activity Table Triggers""" + + def __init__(self, use_statement_level_triggers, schema=None): + self.schema = schema + self.use_statement_level_triggers = use_statement_level_triggers + + @classmethod + def init_activity_table_triggers(cls, operations, use_statement_level_triggers, **kwargs): + op = InitActivityTableTriggersOp(use_statement_level_triggers, **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return RemoveActivityTableTriggersOp(self.use_statement_level_triggers, schema=self.schema) + +@Operations.register_operation("remove_activity_table_triggers") +class RemoveActivityTableTriggersOp(MigrateOperation): + """Drop Activity Table Triggers""" + + def __init__(self, use_statement_level_triggers, schema=None): + self.schema = schema + self.use_statement_level_triggers = use_statement_level_triggers + + + @classmethod + def remove_activity_table_triggers(cls, operations, **kwargs): + op = RemoveActivityTableTriggersOp(False, **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return InitActivityTableTriggersOp(self.use_statement_level_triggers, schema=self.schema) + + +@Operations.implementation_for(InitActivityTableTriggersOp) +def init_activity_table_triggers(operations, operation): + conn = operations + bind = conn.get_bind() + + if operation.schema: + conn.execute(render_tmpl('create_schema.sql', operation.schema)) + + conn.execute(render_tmpl('jsonb_change_key_name.sql', operation.schema)) + create_audit_table(None, bind, operation.schema, operation.use_statement_level_triggers) + create_operators(None, bind, operation.schema) + + +@Operations.implementation_for(RemoveActivityTableTriggersOp) +def remove_activity_table_triggers(operations, operation): + conn = operations + bind = conn.get_bind() + + if operation.schema: + conn.execute(render_tmpl('drop_schema.sql', operation.schema)) + + conn.execute("DROP FUNCTION jsonb_change_key_name(data jsonb, old_key text, new_key text)") + schema_prefix = f"{operation.schema}." if operation.schema else "" + + conn.execute(f"DROP FUNCTION {schema_prefix}audit_table(target_table regclass, ignored_cols text[])") + conn.execute(f"DROP FUNCTION {schema_prefix}create_activity()") + + + if bind.dialect.server_version_info < (9, 5, 0): + conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb, TEXT)""") + conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, text);""") + conn.execute(f"""DROP FUNCTION jsonb_merge(jsonb, jsonb)""") + conn.execute(f"""DROP OPERATOR IF EXISTS || (jsonb, jsonb);""") + if bind.dialect.server_version_info < (9, 6, 0): + conn.execute(f"""DROP FUNCTION current_setting(TEXT, BOOL)""") + if bind.dialect.server_version_info < (10, 0): + conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb, TEXT[])""") + conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, text[])""") + + conn.execute(f"""DROP OPERATOR IF EXISTS - (jsonb, jsonb)""") + conn.execute(f"""DROP FUNCTION get_setting(text, text)""") + conn.execute(f"""DROP FUNCTION jsonb_subtract(jsonb,jsonb)""") + + +@renderers.dispatch_for(InitActivityTableTriggersOp) +def render_init_activity_table_triggers(autogen_context, op): + return "op.init_activity_table_triggers(%r, **%r)" % ( + op.use_statement_level_triggers, + {"schema": op.schema} + ) + +@renderers.dispatch_for(RemoveActivityTableTriggersOp) +def render_remove_activity_table_triggers(autogen_context, op): + return "op.remove_activity_table_triggers(**%r)" % ( + {"schema": op.schema} + ) diff --git a/postgresql_audit/alembic/migration_ops.py b/postgresql_audit/alembic/migration_ops.py new file mode 100644 index 0000000..f22d091 --- /dev/null +++ b/postgresql_audit/alembic/migration_ops.py @@ -0,0 +1,70 @@ +from alembic.autogenerate import renderers +from alembic.operations import Operations, MigrateOperation + +from postgresql_audit import add_column, remove_column + + +@Operations.register_operation("add_column_to_activity") +class AddColumnToActivityOp(MigrateOperation): + """Initialize Activity Table Triggers""" + + def __init__(self, table_name, column_name, default_value=None, schema=None): + self.schema = schema + self.table_name = table_name + self.column_name = column_name + self.default_value = default_value + + @classmethod + def add_column_to_activity(cls, operations, table_name, column_name, **kwargs): + op = AddColumnToActivityOp(table_name, column_name, **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return RemoveColumnFromRemoveActivityOp(self.table_name, self.column_name, default_value=self.default_value, schema=self.schema) + +@Operations.register_operation("remove_column_from_activity") +class RemoveColumnFromRemoveActivityOp(MigrateOperation): + """Drop Activity Table Triggers""" + + def __init__(self, table_name, column_name, default_value=None, schema=None): + self.schema = schema + self.table_name = table_name + self.column_name = column_name + self.default_value = default_value + + @classmethod + def remove_column_from_activity(cls, operations, table_name, column_name, **kwargs): + op = RemoveColumnFromRemoveActivityOp(table_name, column_name, **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return AddColumnToActivityOp(self.table_name, self.column_name, default_value=self.default_value, schema=self.schema) + + +@Operations.implementation_for(AddColumnToActivityOp) +def add_column_to_activity(operations, operation): + add_column(operations, operation.table_name, operation.column_name, default_value=operation.default_value, schema=operation.schema) + + +@Operations.implementation_for(RemoveColumnFromRemoveActivityOp) +def remove_column_from_activity(operations, operation): + conn = operations.connection + remove_column(conn, operation.table_name, operation.column_name, operation.schema) + +@renderers.dispatch_for(AddColumnToActivityOp) +def render_add_column_to_activity(autogen_context, op): + return "op.add_column_to_activity(%r, %r, **%r)" % ( + op.table_name, + op.column_name, + {"schema": op.schema, "default_value": op.default_value} + ) + +@renderers.dispatch_for(RemoveColumnFromRemoveActivityOp) +def render_remove_column_from_activitys(autogen_context, op): + return "op.remove_column_from_activity(%r, %r, **%r)" % ( + op.table_name, + op.column_name, + {"schema": op.schema} + ) diff --git a/postgresql_audit/alembic/register_table_for_version_tracking.py b/postgresql_audit/alembic/register_table_for_version_tracking.py new file mode 100644 index 0000000..70a158a --- /dev/null +++ b/postgresql_audit/alembic/register_table_for_version_tracking.py @@ -0,0 +1,76 @@ +import sqlalchemy as sa + +from alembic.autogenerate import renderers +from alembic.operations import Operations, MigrateOperation + + +@Operations.register_operation("register_for_version_tracking") +class RegisterTableForVersionTrackingOp(MigrateOperation): + """Register Table for Version Tracking""" + + def __init__(self, tablename, excluded_columns, original_excluded_columns=None, schema=None): + self.schema = schema + self.tablename = tablename + self.excluded_columns = excluded_columns + self.original_excluded_columns = original_excluded_columns + + @classmethod + def register_for_version_tracking(cls, operations, tablename, exclude_columns, **kwargs): + op = RegisterTableForVersionTrackingOp(tablename, exclude_columns, **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return DeregisterTableForVersionTrackingOp(self.tablename, self.original_excluded_columns, schema=self.schema) + +@Operations.register_operation("deregister_for_version_tracking") +class DeregisterTableForVersionTrackingOp(MigrateOperation): + """Drop Table from Version Tracking""" + + def __init__(self, tablename, excluded_columns, schema=None): + self.schema = schema + self.tablename = tablename + self.excluded_columns = excluded_columns + + + @classmethod + def deregister_for_version_tracking(cls, operations, tablename, **kwargs): + op = DeregisterTableForVersionTrackingOp(tablename, (), **kwargs) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return RegisterTableForVersionTrackingOp(self.tablename, self.excluded_columns, (), schema=self.schema) + + +@Operations.implementation_for(RegisterTableForVersionTrackingOp) +def register_for_version_tracking(operations, operation): + if operation.schema is None: + func = sa.func.audit_table + else: + func = getattr(getattr(sa.func, operation.schema), 'audit_table') + operations.execute(sa.select([func(operation.tablename, list(operation.excluded_columns))])) + + +@Operations.implementation_for(DeregisterTableForVersionTrackingOp) +def deregister_for_version_tracking(operations, operation): + operations.execute(f"drop trigger if exists audit_trigger_insert on {operation.tablename} ") + operations.execute(f"drop trigger if exists audit_trigger_update on {operation.tablename} ") + operations.execute(f"drop trigger if exists audit_trigger_delete on {operation.tablename} ") + operations.execute(f"drop trigger if exists audit_trigger_row on {operation.tablename} ") + + +@renderers.dispatch_for(RegisterTableForVersionTrackingOp) +def render_register_for_version_tracking(autogen_context, op): + return "op.register_for_version_tracking(%r, %r, **%r)" % ( + op.tablename, + op.excluded_columns, + {"schema": op.schema} + ) + +@renderers.dispatch_for(DeregisterTableForVersionTrackingOp) +def render_deregister_for_version_tracking(autogen_context, op): + return "op.deregister_for_version_tracking(%r, **%r)" % ( + op.tablename, + {"schema": op.schema} + ) diff --git a/postgresql_audit/base.py b/postgresql_audit/base.py index e79133a..d69715d 100644 --- a/postgresql_audit/base.py +++ b/postgresql_audit/base.py @@ -1,7 +1,7 @@ -import os -import string import warnings +from collections import Sequence from contextlib import contextmanager +from functools import partial from weakref import WeakSet import sqlalchemy as sa @@ -18,7 +18,9 @@ from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy_utils import get_class_by_table -HERE = os.path.dirname(os.path.abspath(__file__)) +from postgresql_audit.utils import render_tmpl, StatementExecutor, create_audit_table, create_operators, \ + build_register_table_query + cached_statements = {} @@ -29,23 +31,6 @@ class ImproperlyConfigured(Exception): class ClassNotVersioned(Exception): pass - -class StatementExecutor(object): - def __init__(self, stmt): - self.stmt = stmt - - def __call__(self, target, bind, **kwargs): - tx = bind.begin() - bind.execute(self.stmt) - tx.commit() - - -def read_file(file_): - with open(os.path.join(HERE, file_)) as f: - s = f.read() - return s - - def assign_actor(base, cls, actor_cls): if hasattr(cls, 'actor_id'): return @@ -164,148 +149,22 @@ def convert_callables(values): } -class VersioningManager(object): - _actor_cls = None - - def __init__( - self, - actor_cls=None, - schema_name=None, - use_statement_level_triggers=True - ): - if actor_cls is not None: - self._actor_cls = actor_cls - self.values = {} +class SessionManager(object): + def __init__(self, transaction_cls, values=None): + self.transaction_cls = transaction_cls + self.values = values or {} + self._marked_transactions = set() self.listeners = ( - ( - orm.mapper, - 'instrument_class', - self.instrument_versioned_classes - ), - ( - orm.mapper, - 'after_configured', - self.configure_versioned_classes - ), ( orm.session.Session, 'before_flush', - self.receive_before_flush, + self.before_flush, ), ) - self.schema_name = schema_name - self.table_listeners = self.get_table_listeners() - self.pending_classes = WeakSet() - self.cached_ddls = {} - self.use_statement_level_triggers = use_statement_level_triggers def get_transaction_values(self): return self.values - @contextmanager - def disable(self, session): - session.execute( - "SET LOCAL postgresql_audit.enable_versioning = 'false'" - ) - try: - yield - finally: - session.execute( - "SET LOCAL postgresql_audit.enable_versioning = 'true'" - ) - - def render_tmpl(self, tmpl_name): - file_contents = read_file( - 'templates/{}'.format(tmpl_name) - ).replace('%', '%%').replace('$$', '$$$$') - tmpl = string.Template(file_contents) - context = dict(schema_name=self.schema_name) - - if self.schema_name is None: - context['schema_prefix'] = '' - context['revoke_cmd'] = '' - else: - context['schema_prefix'] = '{}.'.format(self.schema_name) - context['revoke_cmd'] = ( - 'REVOKE ALL ON {schema_prefix}activity FROM public;' - ).format(**context) - - temp = tmpl.substitute(**context) - return temp - - def create_operators(self, target, bind, **kwargs): - if bind.dialect.server_version_info < (9, 5, 0): - StatementExecutor(self.render_tmpl('operators_pre95.sql'))( - target, bind, **kwargs - ) - if bind.dialect.server_version_info < (9, 6, 0): - StatementExecutor(self.render_tmpl('operators_pre96.sql'))( - target, bind, **kwargs - ) - if bind.dialect.server_version_info < (10, 0): - operators_template = self.render_tmpl('operators_pre100.sql') - StatementExecutor(operators_template)(target, bind, **kwargs) - operators_template = self.render_tmpl('operators.sql') - StatementExecutor(operators_template)(target, bind, **kwargs) - - def create_audit_table(self, target, bind, **kwargs): - sql = '' - if ( - self.use_statement_level_triggers and - bind.dialect.server_version_info >= (10, 0) - ): - sql += self.render_tmpl('create_activity_stmt_level.sql') - sql += self.render_tmpl('audit_table_stmt_level.sql') - else: - sql += self.render_tmpl('create_activity_row_level.sql') - sql += self.render_tmpl('audit_table_row_level.sql') - StatementExecutor(sql)(target, bind, **kwargs) - - def get_table_listeners(self): - listeners = {'transaction': []} - - listeners['activity'] = [ - ('after_create', sa.schema.DDL( - self.render_tmpl('jsonb_change_key_name.sql') - )), - ('after_create', self.create_audit_table), - ('after_create', self.create_operators) - ] - if self.schema_name is not None: - listeners['transaction'] = [ - ('before_create', sa.schema.DDL( - self.render_tmpl('create_schema.sql') - )), - ('after_drop', sa.schema.DDL( - self.render_tmpl('drop_schema.sql') - )), - ] - return listeners - - def audit_table(self, table, exclude_columns=None): - args = [table.name] - if exclude_columns: - for column in exclude_columns: - if column not in table.c: - raise ImproperlyConfigured( - "Could not configure versioning. Table '{}'' does " - "not have a column named '{}'.".format( - table.name, column - ) - ) - args.append(array(exclude_columns)) - - if self.schema_name is None: - func = sa.func.audit_table - else: - func = getattr(getattr(sa.func, self.schema_name), 'audit_table') - query = sa.select([func(*args)]) - if query not in cached_statements: - cached_statements[query] = StatementExecutor(query) - listener = (table, 'after_create', cached_statements[query]) - if not sa.event.contains(*listener): - sa.event.listen(*listener) - def set_activity_values(self, session): dialect = session.bind.engine.dialect table = self.transaction_cls.__table__ @@ -349,9 +208,10 @@ def modified_columns(self, obj): def is_modified(self, obj_or_session): if hasattr(obj_or_session, '__mapper__'): - if not hasattr(obj_or_session, '__versioned__'): + version_info = self.__get_versioned_info(obj_or_session) + if not version_info: raise ClassNotVersioned(obj_or_session.__class__.__name__) - excluded = obj_or_session.__versioned__.get('exclude', []) + excluded = version_info.get('exclude', []) return bool( set([ column.name @@ -362,43 +222,66 @@ def is_modified(self, obj_or_session): return any( self.is_modified(entity) or entity in obj_or_session.deleted for entity in obj_or_session - if hasattr(entity, '__versioned__') + if self.__get_versioned_info(entity) ) - def receive_before_flush(self, session, flush_context, instances): + def __get_versioned_info(self, entity): + v_args = getattr(entity, '__versioned__', None) + if v_args: + return v_args + table_args = getattr(entity, '__table_args__', None) + if not table_args: + return None + if isinstance(table_args, Sequence): + table_args = next((x for x in iter(table_args) if isinstance(x, dict)), None) + if not table_args: + return None + return table_args.get("info", {}).get("versioned", None) + + def before_flush(self, session, flush_context, instances): + if session.transaction in self._marked_transactions: + return + if session.transaction: + self.add_entry_and_mark_transaction(session) + + def add_entry_and_mark_transaction(self, session): if self.is_modified(session): + self._marked_transactions.add(session.transaction) self.set_activity_values(session) - def instrument_versioned_classes(self, mapper, cls): - """ - Collect versioned class and add it to pending_classes list. - - :mapper mapper: SQLAlchemy mapper object - :cls cls: SQLAlchemy declarative class - """ - if hasattr(cls, '__versioned__') and cls not in self.pending_classes: - self.pending_classes.add(cls) + def attach_listeners(self): + for listener in self.listeners: + sa.event.listen(*listener) - def configure_versioned_classes(self): - """ - Configures all versioned classes that were collected during - instrumentation process. - """ - for cls in self.pending_classes: - self.audit_table(cls.__table__, cls.__versioned__.get('exclude')) - assign_actor(self.base, self.transaction_cls, self.actor_cls) + def remove_listeners(self): + for listener in self.listeners: + sa.event.remove(*listener) - def attach_table_listeners(self): - for values in self.table_listeners['transaction']: - sa.event.listen(self.transaction_cls.__table__, *values) - for values in self.table_listeners['activity']: - sa.event.listen(self.activity_cls.__table__, *values) +class BasicVersioningManager(object): + _actor_cls = None + _session_manager_factory = partial(SessionManager, values={}) - def remove_table_listeners(self): - for values in self.table_listeners['transaction']: - sa.event.remove(self.transaction_cls.__table__, *values) - for values in self.table_listeners['activity']: - sa.event.remove(self.activity_cls.__table__, *values) + def __init__( + self, + actor_cls=None, + session_manager_factory=None, + schema_name=None, + use_statement_level_triggers=True + ): + if actor_cls is not None: + self._actor_cls = actor_cls + if session_manager_factory is not None: + self._session_manager_factory = session_manager_factory + self.values = {} + self.listeners = ( + ( + orm.mapper, + 'after_configured', + self.after_configured + ), + ) + self.schema_name = schema_name + self.use_statement_level_triggers = use_statement_level_triggers @property def actor_cls(self): @@ -424,15 +307,8 @@ def actor_cls(self): ) return self._actor_cls - def attach_listeners(self): - self.attach_table_listeners() - for listener in self.listeners: - sa.event.listen(*listener) - - def remove_listeners(self): - self.remove_table_listeners() - for listener in self.listeners: - sa.event.remove(*listener) + def after_configured(self): + assign_actor(self.base, self.transaction_cls, self.actor_cls) def activity_model_factory(self, base, transaction_cls): class Activity(activity_base(base, self.schema_name, transaction_cls)): @@ -446,6 +322,28 @@ class Transaction(transaction_base(base, self.schema_name)): return Transaction + def attach_listeners(self): + for listener in self.listeners: + sa.event.listen(*listener) + self.session_manager.attach_listeners() + + def remove_listeners(self): + for listener in self.listeners: + sa.event.remove(*listener) + self.session_manager.remove_listeners() + + @contextmanager + def disable(self, session): + session.execute( + "SET LOCAL postgresql_audit.enable_versioning = 'false'" + ) + try: + yield + finally: + session.execute( + "SET LOCAL postgresql_audit.enable_versioning = 'true'" + ) + def init(self, base): self.base = base self.transaction_cls = self.transaction_model_factory(base) @@ -453,7 +351,123 @@ def init(self, base): base, self.transaction_cls ) + self.session_manager = self._session_manager_factory(self.transaction_cls) self.attach_listeners() +class VersioningManager(BasicVersioningManager): + def __init__( + self, + actor_cls=None, + session_manager_factory=None, + schema_name=None, + use_statement_level_triggers=True + ): + super().__init__( + actor_cls=actor_cls, + schema_name=schema_name, + use_statement_level_triggers=use_statement_level_triggers, + session_manager_factory=session_manager_factory + ) + self.listeners = ( + ( + orm.mapper, + 'instrument_class', + self.instrument_versioned_classes + ), + ( + orm.mapper, + 'after_configured', + self.configure_versioned_classes + ), + ) + self.table_listeners = self.get_table_listeners() + self.pending_classes = WeakSet() + self.cached_ddls = {} + + def get_table_listeners(self): + listeners = {'transaction': []} + + listeners['activity'] = [ + ('after_create', sa.schema.DDL( + render_tmpl('jsonb_change_key_name.sql', self.schema_name) + )), + ('after_create', partial( + create_audit_table, + schema_name=self.schema_name, + use_statement_level_triggers=self.use_statement_level_triggers + ) + ), + ('after_create', partial(create_operators, schema_name=self.schema_name)) + ] + if self.schema_name is not None: + listeners['transaction'] = [ + ('before_create', sa.schema.DDL( + render_tmpl('create_schema.sql', self.schema_name) + )), + ('after_drop', sa.schema.DDL( + render_tmpl('drop_schema.sql', self.schema_name) + )), + ] + return listeners + + def audit_table(self, table, exclude_columns=None): + args = [table.name] + if exclude_columns: + for column in exclude_columns: + if column not in table.c: + raise ImproperlyConfigured( + "Could not configure versioning. Table '{}'' does " + "not have a column named '{}'.".format( + table.name, column + ) + ) + args.append(array(exclude_columns)) + query = build_register_table_query(self.schema_name, *args) + if query not in cached_statements: + cached_statements[query] = StatementExecutor(query) + listener = (table, 'after_create', cached_statements[query]) + if not sa.event.contains(*listener): + sa.event.listen(*listener) + + def instrument_versioned_classes(self, mapper, cls): + """ + Collect versioned class and add it to pending_classes list. + + :mapper mapper: SQLAlchemy mapper object + :cls cls: SQLAlchemy declarative class + """ + if hasattr(cls, '__versioned__') and cls not in self.pending_classes: + self.pending_classes.add(cls) + + def configure_versioned_classes(self): + """ + Configures all versioned classes that were collected during + instrumentation process. + """ + for cls in self.pending_classes: + self.audit_table(cls.__table__, cls.__versioned__.get('exclude')) + assign_actor(self.base, self.transaction_cls, self.actor_cls) + + def attach_table_listeners(self): + for values in self.table_listeners['transaction']: + sa.event.listen(self.transaction_cls.__table__, *values) + for values in self.table_listeners['activity']: + sa.event.listen(self.activity_cls.__table__, *values) + + def remove_table_listeners(self): + for values in self.table_listeners['transaction']: + sa.event.remove(self.transaction_cls.__table__, *values) + for values in self.table_listeners['activity']: + sa.event.remove(self.activity_cls.__table__, *values) + + def attach_listeners(self): + self.attach_table_listeners() + super().attach_listeners() + + def remove_listeners(self): + self.remove_table_listeners() + super().remove_listeners() + + versioning_manager = VersioningManager() diff --git a/postgresql_audit/flask.py b/postgresql_audit/flask.py index 8c788ed..049b68e 100644 --- a/postgresql_audit/flask.py +++ b/postgresql_audit/flask.py @@ -6,12 +6,10 @@ from flask import g, request from flask.globals import _app_ctx_stack, _request_ctx_stack -from .base import VersioningManager as BaseVersioningManager +from .base import VersioningManager, SessionManager -class VersioningManager(BaseVersioningManager): - _actor_cls = 'User' - +class FlaskSessionManager(SessionManager): def get_transaction_values(self): values = copy(self.values) if context_available() and hasattr(g, 'activity_values'): @@ -65,4 +63,4 @@ def activity_values(**values): del g.activity_values -versioning_manager = VersioningManager() +versioning_manager = VersioningManager(actor_cls="User", session_manager_factory=FlaskSessionManager) diff --git a/postgresql_audit/utils.py b/postgresql_audit/utils.py new file mode 100644 index 0000000..127e694 --- /dev/null +++ b/postgresql_audit/utils.py @@ -0,0 +1,75 @@ +import os +import string +import sqlalchemy as sa + +HERE = os.path.dirname(os.path.abspath(__file__)) + + +class StatementExecutor(object): + def __init__(self, stmt): + self.stmt = stmt + + def __call__(self, target, bind, **kwargs): + tx = bind.begin() + bind.execute(self.stmt) + tx.commit() + +def read_file(file_): + with open(os.path.join(HERE, file_)) as f: + s = f.read() + return s + +def render_tmpl(tmpl_name, schema_name=None): + file_contents = read_file( + 'templates/{}'.format(tmpl_name) + ).replace('%', '%%').replace('$$', '$$$$') + tmpl = string.Template(file_contents) + context = dict(schema_name=schema_name) + + if schema_name is None: + context['schema_prefix'] = '' + context['revoke_cmd'] = '' + else: + context['schema_prefix'] = '{}.'.format(schema_name) + context['revoke_cmd'] = ( + 'REVOKE ALL ON {schema_prefix}activity FROM public;' + ).format(**context) + + return tmpl.substitute(**context) + + +def create_operators(target, bind, schema_name, **kwargs): + if bind.dialect.server_version_info < (9, 5, 0): + StatementExecutor(render_tmpl('operators_pre95.sql', schema_name))( + target, bind, **kwargs + ) + if bind.dialect.server_version_info < (9, 6, 0): + StatementExecutor(render_tmpl('operators_pre96.sql', schema_name))( + target, bind, **kwargs + ) + if bind.dialect.server_version_info < (10, 0): + operators_template = render_tmpl('operators_pre100.sql', schema_name) + StatementExecutor(operators_template)(target, bind, **kwargs) + operators_template = render_tmpl('operators.sql', schema_name) + StatementExecutor(operators_template)(target, bind, **kwargs) + +def create_audit_table(target, bind, schema_name, use_statement_level_triggers, **kwargs): + sql = '' + if ( + use_statement_level_triggers and + bind.dialect.server_version_info >= (10, 0) + ): + sql += render_tmpl('create_activity_stmt_level.sql', schema_name) + sql += render_tmpl('audit_table_stmt_level.sql', schema_name) + else: + sql += render_tmpl('create_activity_row_level.sql', schema_name) + sql += render_tmpl('audit_table_row_level.sql', schema_name) + StatementExecutor(sql)(target, bind, **kwargs) + + +def build_register_table_query(schema_name, *args): + if schema_name is None: + func = sa.func.audit_table + else: + func = getattr(getattr(sa.func, schema_name), 'audit_table') + return sa.select([func(*args)])