from __future__ import annotations
import copy
import pickle
from datetime import datetime, timedelta
from functools import partial
from typing import TYPE_CHECKING
from loguru import logger
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Unicode, select
from sqlalchemy.orm import relationship
from flexget import db_schema
from flexget.entry import Entry
from flexget.event import event
from flexget.manager import Session
from flexget.plugin import PluginError
from flexget.utils import json, serialization
from flexget.utils.database import entry_synonym
from flexget.utils.sqlalchemy_utils import table_add_column, table_schema
from flexget.utils.tools import TimedDict, get_config_hash, parse_timedelta
# Tag every log record emitted by this module with name='input_cache'.
logger = logger.bind(name='input_cache')

if TYPE_CHECKING:
    from collections.abc import Callable, Iterable

    from sqlalchemy.orm import Session as DBSession

    # Typing-only stand-in for the declarative base: lets static type checkers accept
    # arbitrary constructor arguments such as InputCacheEntry(entry=...).
    class Base:
        def __init__(self, *args, **kwargs) -> None: ...

else:
    # Runtime declarative base for this module's tables. Schema version 2 is the final
    # version that upgrade() below migrates to.
    Base = db_schema.versioned_base('input_cache', 2)
@db_schema.upgrade('input_cache')
def upgrade(ver: int, session: DBSession) -> int:
    """Migrate the ``input_cache`` schema from version ``ver`` up to version 2.

    * 0 -> 1: add the ``json`` column to ``input_cache_entry`` and fill it from the legacy
      pickled ``entry`` column.
    * 1 -> 2: re-serialize the stored JSON through :mod:`flexget.utils.serialization`.

    Rows that cannot be converted are logged and skipped rather than aborting the upgrade.

    :param ver: Schema version currently recorded for ``input_cache``.
    :param session: Database session the migration runs in.
    :return: The schema version after migration.
    """
    if ver == 0:
        table = table_schema('input_cache_entry', session)
        table_add_column(table, 'json', Unicode, session)
        # Make sure we get the new schema with the added column
        table = table_schema('input_cache_entry', session)
        # Fetch all rows up front so the UPDATEs below are not issued while the SELECT
        # result is still being iterated on the same connection.
        rows = session.execute(select(table.c.id, table.c.entry)).all()
        for row in rows:
            try:
                # NOTE: unpickling is tolerated only because these are legacy rows from
                # FlexGet's own database; never unpickle data from untrusted sources.
                p = pickle.loads(row.entry)
                session.execute(
                    table
                    .update()
                    .where(table.c.id == row.id)
                    .values(json=json.dumps(p, encode_datetime=True))
                )
            except (
                KeyError,
                TypeError,
                EOFError,
                AttributeError,
                ImportError,
                pickle.UnpicklingError,
            ) as ex:
                # Best effort: a single unreadable legacy row must not abort the migration.
                logger.error('Unable to upgrade input_cache pickle object due to {}', ex)
        ver = 1
    if ver == 1:
        table = table_schema('input_cache_entry', session)
        rows = session.execute(select(table.c.id, table.c.json)).all()
        for row in rows:
            if not row.json:
                # Seems there could be invalid data somehow. See #2590
                continue
            try:
                data = json.loads(row.json, decode_datetime=True)
            except ValueError as ex:
                logger.error('Unable to upgrade input_cache json object due to {}', ex)
                continue
            # If title looked like a date, make sure it's a string
            # Had a weird case of an entry without a title: https://github.com/Flexget/Flexget/issues/2636
            title = data.pop('title', None)
            entry = partial(Entry, **data)
            e = entry(title=str(title)) if title else entry()
            session.execute(
                table.update().where(table.c.id == row.id).values(json=serialization.dumps(e))
            )
        ver = 2
    return ver
class InputCacheEntry(Base):
    """One cached entry, linked to its parent row in the ``input_cache`` table."""

    __tablename__ = 'input_cache_entry'

    id = Column(Integer, primary_key=True)
    # Serialized entry, stored in the ``json`` column (added by upgrade() at version 1).
    _json = Column('json', Unicode)
    # Entry-valued view of ``_json``: the cache code assigns Entry objects via ``entry=...``
    # and reads them back from ``.entry``; (de)serialization is presumably handled by
    # entry_synonym — see flexget.utils.database.
    entry = entry_synonym('_json')
    # Owning cache; required for every entry.
    cache_id = Column(Integer, ForeignKey('input_cache.id'), nullable=False)
@event('manager.db_cleanup')
def db_cleanup(manager, session: DBSession) -> None:
    """Remove input caches, and their entries, that have not been refreshed for 7 days.

    A cache's ``added`` timestamp is reset each time it is stored, so an old cache most
    likely belongs to an input that is no longer configured.

    :param manager: The manager firing the event (unused).
    :param session: Database session the deletes run in.
    """
    cutoff = datetime.now() - timedelta(days=7)
    old_cache_ids = select(InputCache.id).where(InputCache.added < cutoff)
    # Bulk Query.delete() does not apply ORM relationship cascades, and the
    # input_cache_entry foreign key declares no ON DELETE action. Remove the child rows
    # explicitly first so no orphaned entries are left behind.
    session.query(InputCacheEntry).filter(InputCacheEntry.cache_id.in_(old_cache_ids)).delete(
        synchronize_session=False
    )
    result = session.query(InputCache).filter(InputCache.added < cutoff).delete()
    if result:
        logger.verbose('Removed {} old input caches.', result)
class cached:  # noqa: N801 It acts like a function in usage
    """Transparent caching decorator ``@cached`` for input plugin methods.

    The wrapped method must be called as ``method(plugin, task, config, ...)``. The task is
    read from the second positional argument and the config from the third. Results are
    kept in an in-memory cache for 5 minutes under ``<name>_<config hash>``.

    When ``persist`` is given, results are also stored in the database and reused while they
    are younger than ``persist``. If the input raises :class:`PluginError`, a stored copy is
    returned even if it has expired. ``task.options.nocache`` disables reading from both
    caches; fresh results are still stored.

    :param name: Prefix of the cache name; it is combined with a hash of the plugin config.
    :param persist: Optional duration string, parsed with ``parse_timedelta``, that enables
        the database cache.

    .. note:: Configuration assumptions may make this unusable in some (future) inputs
    """

    # Shared by every decorated input; keys are ``<name>_<config hash>``.
    cache = TimedDict(cache_time='5 minutes')

    def __init__(self, name: str, persist: str | None = None) -> None:
        # Cast name to str to prevent sqlalchemy warnings when filtering
        self.name = str(name)
        # Parse persist time
        self.persist: timedelta | None = parse_timedelta(persist) if persist else None
        # Values from the most recent call of the wrapped function. They are kept for
        # backwards compatibility; the wrapper itself passes per-call values explicitly.
        self.config_hash: str | None = None
        self.cache_name: str | None = None

    def __call__(self, func):
        """Wrap ``func`` so its results are served from, and stored to, the caches."""
        from functools import wraps

        @wraps(func)
        def wrapped_func(*args, **kwargs):
            # get task from method parameters
            task = args[1]
            # Keep the cache key local to this call. The IterableCache finished hook runs
            # lazily, possibly after another call has overwritten the shared attributes.
            config_hash = get_config_hash(args[2])
            cache_name = self.name + '_' + config_hash
            self.config_hash = config_hash
            self.cache_name = cache_name
            logger.trace('self.name: {}', self.name)
            logger.trace('hash: {}', config_hash)
            logger.debug('cache name: {} (has: {})', cache_name, ', '.join(self.cache.keys()))

            if not task.options.nocache:
                cache_value = self.cache.get(cache_name, None)
                if cache_value:
                    # return from the cache
                    logger.verbose('Restored entries from cache')
                    return cache_value
                if self.persist:
                    # Check database cache
                    db_cache = self.load_from_db(config_hash=config_hash)
                    if db_cache is not None:
                        return db_cache

            # Nothing was restored from db or memory cache, run the function
            logger.trace('cache miss')
            # NOTE: only errors raised by the call itself are caught here. If func returns a
            # lazy iterable, errors raised while iterating it propagate from IterableCache.
            try:
                response = func(*args, **kwargs) or []
            except PluginError as e:
                # If there was an error producing entries, but we have valid entries in the db cache, return those.
                if self.persist and not task.options.nocache:
                    cache = self.load_from_db(load_expired=True, config_hash=config_hash)
                    if cache is not None:
                        logger.error(
                            'There was an error during {} input ({}), using cache instead.',
                            self.name,
                            e,
                        )
                        return cache
                # If there was nothing in the db cache, re-raise the error.
                raise

            # store results to cache
            logger.debug('storing entries to cache {} ', cache_name)
            finished_hook = (
                partial(self.store_to_db, config_hash=config_hash) if self.persist else None
            )
            cache = IterableCache(response, finished_hook)
            self.cache[cache_name] = cache
            return cache

        return wrapped_func

    def _cache_key(self, config_hash: str | None) -> tuple[str | None, str | None]:
        """Return ``(config_hash, cache_name)``, defaulting to the most recent call's values."""
        if config_hash is None:
            return self.config_hash, self.cache_name
        return config_hash, self.name + '_' + config_hash

    def store_to_db(self, entries: list[Entry], *, config_hash: str | None = None) -> None:
        """Replace the database cache for this input and config with ``entries``.

        :param entries: Entries to persist.
        :param config_hash: Hash identifying the cached config. Defaults to the hash from the
            most recent call of the wrapped function.
        """
        config_hash, cache_name = self._cache_key(config_hash)
        logger.debug('Storing cache {} to database.', cache_name)
        with Session() as session:
            db_cache = (
                session
                .query(InputCache)
                .filter(InputCache.name == self.name)
                .filter(InputCache.hash == config_hash)
                .first()
            )
            if not db_cache:
                db_cache = InputCache(name=self.name, hash=config_hash)
            db_cache.entries = [InputCacheEntry(entry=ent) for ent in entries]
            db_cache.added = datetime.now()
            session.merge(db_cache)

    def load_from_db(
        self, load_expired: bool = False, *, config_hash: str | None = None
    ) -> list[Entry] | None:
        """Load cached entries from the database and prime the in-memory cache with them.

        :param load_expired: Also accept a cache older than ``persist``.
        :param config_hash: Hash identifying the cached config. Defaults to the hash from the
            most recent call of the wrapped function.
        :return: The restored entries, or ``None`` when no matching cache exists.
        """
        config_hash, cache_name = self._cache_key(config_hash)
        with Session() as session:
            query = (
                session
                .query(InputCache)
                .filter(InputCache.name == self.name)
                .filter(InputCache.hash == config_hash)
            )
            if self.persist and not load_expired:
                query = query.filter(InputCache.added > datetime.now() - self.persist)
            db_cache = query.first()
            if not db_cache:
                return None
            entries = [ent.entry for ent in db_cache.entries]
            logger.verbose('Restored {} entries from db cache for {}', len(entries), self.name)
            # Store a private copy in the memory cache. IterableCache hands out deep copies on
            # every iteration, so consumers mutating restored entries cannot corrupt it.
            self.cache[cache_name] = IterableCache(copy.deepcopy(entries))
            return entries
class IterableCache:
"""Can cache any iterable (including generators) without immediately evaluating all entries.
If `finished_hook` is supplied, it will be called the first time the iterable is run to the end.
"""
def __init__(self, iterable: Iterable, finished_hook: Callable[[list], None] | None = None):
self.iterable = iter(iterable)
self.cache: list = []
self.finished_hook = finished_hook
def __iter__(self):
for item in self.cache:
yield copy.deepcopy(item)
for item in self.iterable:
self.cache.append(item)
yield copy.deepcopy(item)
# The first time we iterate through all items, call our finished hook with complete list of items
if self.finished_hook:
self.finished_hook(self.cache)
self.finished_hook = None