Source code for flexget.db_schema

from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, Any

import sqlalchemy.event
from loguru import logger
from sqlalchemy import Column, DateTime, Integer, String
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import as_declarative

import flexget
from flexget.event import event
from flexget.manager import Base, Session
from flexget.utils.database import with_session
from flexget.utils.sqlalchemy_utils import table_schema
from flexget.utils.tools import get_current_flexget_version

if TYPE_CHECKING:
    from collections.abc import Callable

    from sqlalchemy import Table

# Rebind the module-level logger with name='schema' attached to its records.
logger = logger.bind(name='schema')

# Registry of versioned plugin schemas, shaped as
# {plugin: {'version': version, 'tables': ['table_names']}}
plugin_schemas: dict[str, dict[str, Any]] = {}


class FlexgetVersion(Base):
    """Row of the ``flexget_version`` table holding a FlexGet version string."""

    __tablename__ = 'flexget_version'

    # The version string itself is the primary key.
    version = Column(String, primary_key=True)
    # Filled in by the column default (``datetime.now``) when no value is given.
    created = Column(DateTime, default=datetime.now)

    def __init__(self) -> None:
        """Create a row stamped with the currently running FlexGet version."""
        self.version = get_current_flexget_version()
@event('manager.startup')
def set_flexget_db_version(manager=None) -> None:
    """Store the running FlexGet version in the database on manager startup.

    Inserts a new :class:`FlexgetVersion` row when none exists, updates the
    existing row (version and timestamp) when it differs from the running
    version, and otherwise only logs that the version is already recorded.

    :param manager: Unused; accepted so the function can serve as an event handler.
    """
    with Session() as session:
        stored = session.query(FlexgetVersion).first()
        running_version = get_current_flexget_version()

        if not stored:
            logger.debug('entering flexget version {} to db', running_version)
            # NOTE(review): no explicit commit on this branch; presumably the
            # Session context manager commits on exit - confirm.
            session.add(FlexgetVersion())
        elif stored.version == running_version:
            logger.debug('current flexget version already exist in db {}', stored.version)
        else:
            logger.debug('updating flexget version {} in db', running_version)
            stored.version = running_version
            stored.created = datetime.now()
            session.commit()
def get_flexget_db_version() -> str | None:
    """Return the FlexGet version stored in the database.

    :return: The ``version`` of the first :class:`FlexgetVersion` row, or
        ``None`` when the table holds no row.
    """
    with Session() as session:
        row = session.query(FlexgetVersion).first()
        return row.version if row else None
class PluginSchema(Base):
    """Row of the ``plugin_schema`` table: the schema version stored for one plugin."""

    __tablename__ = 'plugin_schema'

    id = Column(Integer, primary_key=True)
    # Name of the plugin this row belongs to.
    plugin = Column(String)
    # Schema version currently recorded for that plugin.
    version = Column(Integer)

    def __init__(self, plugin: str, version: int = 0) -> None:
        """Create a row for ``plugin`` at schema ``version`` (default ``0``)."""
        self.plugin = plugin
        self.version = version

    def __str__(self) -> str:
        """Return a short ``<PluginSchema(...)>`` description of this row."""
        return f'<PluginSchema(plugin={self.plugin},version={self.version})>'
@with_session
def get_version(plugin: str, session=None) -> int | None:
    """Return the schema version stored in the database for ``plugin``.

    :param plugin: Name of the plugin to look up.
    :param session: Database session used for the query.
    :return: The stored version, or ``None`` when no row exists for the plugin.
    """
    stored = session.query(PluginSchema).filter(PluginSchema.plugin == plugin).first()
    if stored:
        return stored.version
    logger.debug('No schema version stored for {}', plugin)
    return None


@with_session
def set_version(plugin: str, version: int, session=None) -> None:
    """Store ``version`` as the schema version of ``plugin`` and commit.

    :param plugin: Name of a plugin registered in ``plugin_schemas``.
    :param version: Version to store; must equal the registered version.
    :param session: Database session used for the query and the commit.
    :raises ValueError: If the plugin is not registered, if ``version`` differs
        from the registered version, or if it is lower than the stored version.
    """
    if plugin not in plugin_schemas:
        raise ValueError(
            f'Tried to set schema version for {plugin} plugin with no versioned_base.'
        )

    base_version = plugin_schemas[plugin]['version']
    if version != base_version:
        raise ValueError(
            f'Tried to set {plugin} plugin schema version to {version} when '
            f'it should be {base_version} as defined in versioned_base.'
        )

    stored = session.query(PluginSchema).filter(PluginSchema.plugin == plugin).first()
    if not stored:
        logger.debug('Initializing plugin {} schema version to {}', plugin, version)
        session.add(PluginSchema(plugin, version))
    elif version < stored.version:
        # Never silently downgrade a stored schema version.
        raise ValueError(f'Tried to set plugin {plugin} schema version to lower value')
    elif version != stored.version:
        logger.debug('Updating plugin {} schema version to {}', plugin, version)
        stored.version = version
    session.commit()


@with_session
def upgrade_required(session=None) -> bool:
    """Return true if an upgrade of the database is required.

    An upgrade is considered necessary when fewer schema rows are stored than
    plugins are registered, or when any stored row of a registered plugin has
    a lower version than the registered one.

    :param session: Database session used for the query.
    """
    stored_schemas = session.query(PluginSchema).all()
    if len(stored_schemas) < len(plugin_schemas):
        return True
    return any(
        stored.plugin in plugin_schemas
        and stored.version < plugin_schemas[stored.plugin]['version']
        for stored in stored_schemas
    )
class UpgradeImpossible(Exception):
    """Signal from a db upgrade function that the schema cannot be upgraded.

    Raising it during an upgrade causes the plugin's old tables to be removed
    and recreated from the new model.
    """
def upgrade(plugin: str) -> Callable:
    """Use as a decorator to register a schema upgrade function.

    The wrapped function is passed the currently stored schema version
    (``None`` when no version is stored) and a session object, and should
    return the new version of the schema after the upgrade. There is no need
    to commit the session; it is committed automatically when an upgraded
    schema version is returned.

    Example::

        from flexget import db_schema

        @db_schema.upgrade('your_plugin')
        def upgrade(ver, session):
            if ver == 2:
                # upgrade
                ver = 3
            return ver

    :param plugin: Name of the plugin whose schema the function upgrades.
    :return: A decorator that registers the function on the
        ``manager.upgrade`` event.
    """

    def upgrade_decorator(upgrade_func):
        @event('manager.upgrade')
        def upgrade_wrapper(manager):
            with Session() as session:
                stored_ver = get_version(plugin, session=session)
                try:
                    returned_ver = upgrade_func(stored_ver, session)
                except UpgradeImpossible:
                    # The plugin gave up on migrating: drop and recreate its tables.
                    logger.info(
                        'Plugin {} database is not upgradable. Flushing data and regenerating.',
                        plugin,
                    )
                    reset_schema(plugin, session=session)
                    manager.db_upgraded = True
                    return
                except Exception:
                    logger.exception('Failed to upgrade database for plugin {}', plugin)
                    session.rollback()
                    manager.shutdown(finish_queue=False)
                    return

                # A plugin without a stored version compares as -1, so any
                # returned version of 0 or above counts as an upgrade.
                baseline_ver = stored_ver if stored_ver is not None else -1
                if returned_ver > baseline_ver:
                    logger.info('Plugin `{}` schema upgraded successfully', plugin)
                    set_version(plugin, returned_ver, session=session)
                    manager.db_upgraded = True
                elif returned_ver < baseline_ver:
                    logger.critical(
                        'A lower schema version was returned ({}) from plugin '
                        '{} upgrade function than passed in ({})',
                        returned_ver,
                        plugin,
                        baseline_ver,
                    )
                    session.rollback()
                    manager.shutdown(finish_queue=False)

        return upgrade_wrapper

    return upgrade_decorator
@with_session
def reset_schema(plugin: str, session=None) -> None:
    """Remove all tables of the given plugin from the database and forget its stored schema version.

    Empty tables are recreated afterwards from the current models.

    :param plugin: The plugin whose schema should be reset
    :param session: Database session used for dropping, deleting and committing.
    :raises ValueError: If the plugin has no entry in ``plugin_schemas``.
    """
    if plugin not in plugin_schemas:
        raise ValueError(f'The plugin {plugin} has no stored schema to reset.')

    # Look up every table first, then drop them one by one.
    registered_names = plugin_schemas[plugin].get('tables', [])
    reflected = [table_schema(name, session) for name in registered_names]
    for plugin_table in reflected:
        try:
            plugin_table.drop(bind=session.bind)
        except OperationalError as exc:
            # A table that is already gone is fine; anything else is a real error.
            if 'no such table' not in str(exc):
                raise

    # Remove the plugin from schema table
    session.query(PluginSchema).filter(PluginSchema.plugin == plugin).delete()
    # We need to commit our current changes to close the session before calling create_all
    session.commit()
    # Create new empty tables
    Base.metadata.create_all(bind=session.bind)
def register_plugin_table(tablename: str, plugin: str, version: int) -> None:
    """Record that ``tablename`` belongs to ``plugin`` at schema ``version``.

    The first registration for a plugin creates its entry in
    ``plugin_schemas``; later registrations must use the same version.
    Registering the same table name again is a no-op, so the plugin's table
    list never contains duplicates.

    :param tablename: Name of the database table to register.
    :param plugin: Name of the plugin owning the table.
    :param version: Schema version of the plugin.
    :raises RuntimeError: If the plugin is already registered with a different version.
    """
    info = plugin_schemas.setdefault(plugin, {'version': version, 'tables': []})
    if info['version'] != version:
        raise RuntimeError(f'Two different schema versions received for plugin {plugin}')
    # Keep registration idempotent: a repeated name must not be listed twice,
    # otherwise consumers of the list would process the same table twice.
    if tablename not in info['tables']:
        info['tables'].append(tablename)
class VersionedBaseMeta(DeclarativeMeta):
    """Metaclass for objects returned by versioned_base factory."""

    def __new__(cls, metaname, bases, dict_):
        """Create the class and register its table unless it is the ``VersionedBase`` root itself."""
        new_class = super().__new__(cls, str(metaname), bases, dict_)
        if metaname == 'VersionedBase':
            # The base class carries no table of its own.
            return new_class
        register_plugin_table(new_class.__tablename__, new_class._plugin, new_class._version)
        return new_class

    def register_table(self, table: str | Table) -> None:
        """Can be used if a plugin is declaring non-declarative sqlalchemy tables.

        :param table: Can either be the name of the table, or an :class:`sqlalchemy.Table` instance.
        """
        table_name = table if isinstance(table, str) else table.name
        register_plugin_table(table_name, self._plugin, self._version)
def versioned_base(plugin: str, version: int) -> VersionedBaseMeta:
    """Return a class which can be used like Base, but automatically stores schema version when tables are created.

    :param plugin: Name of the plugin the returned base class belongs to.
    :param version: Schema version of that plugin.
    :return: A declarative base class sharing ``Base.metadata``.
    """
    declarative = as_declarative(metaclass=VersionedBaseMeta, metadata=Base.metadata)

    @declarative
    class VersionedBase:
        # Read by VersionedBaseMeta when subclasses and tables are registered.
        _plugin = plugin
        _version = version

    return VersionedBase
@sqlalchemy.event.listens_for(Base.metadata, 'after_create')
def after_table_create(target, connection, tables: list[Table] | None = None, **kw) -> None:
    """Set the schema version to most recent for a plugin when its tables are freshly created.

    :param target: Unused; supplied by the ``after_create`` event.
    :param connection: Unused; supplied by the ``after_create`` event.
    :param tables: Tables that were just created, if any.
    """
    if not tables:
        return

    # TODO: Detect if any database upgrading is needed and acquire the lock only in one place
    with flexget.manager.manager.acquire_lock(event=False):
        created_names = {created.name for created in tables}
        for plugin, info in plugin_schemas.items():
            # Only set the version if all tables for a given plugin are being created
            if created_names.issuperset(info['tables']):
                set_version(plugin, info['version'])