from __future__ import annotations

import contextlib
import hmac
import logging
import os
import random
import secrets
import string
import threading
import unittest.mock
from typing import TYPE_CHECKING

import rich.text
import rpyc
from loguru import logger
from rpyc.utils.server import ThreadedServer

from flexget import terminal
from flexget.log import capture_logs
from flexget.options import get_parser

if TYPE_CHECKING:
    from collections.abc import Callable
logger = logger.bind(name='ipc')

# rpyc connections start from these protocol defaults, so the tweaks below apply to
# both the daemon side and the client side of this module.
_rpyc_defaults = rpyc.core.protocol.DEFAULT_CONFIG
# Allow some attributes from dict interface to be called over the wire
_rpyc_defaults['safe_attrs'].update(['items'])
# NOTE(review): pickling lets ``rpyc.classic.obtain`` copy whole objects across the wire
# (e.g. the log messages received by ``ClientService.exposed_log_sink``). This is only
# acceptable between trusted peers: the server binds to 127.0.0.1 and demands a password.
_rpyc_defaults['allow_pickle'] = True

# The client compares this with the daemon's value on connect and aborts on mismatch.
IPC_VERSION = 4
# Replies the daemon sends after checking the password a client sends first.
AUTH_ERROR = b'authentication error'
AUTH_SUCCESS = b'authentication success'
class RemoteStream:
"""Used as a filelike to stream text to remote client.
If client disconnects while this is in use, an error will be
logged, but no exception raised.
"""
def __init__(self, writer: Callable | None):
""":param writer: A function which writes a line of text to remote client."""
self.buffer = ''
self.writer = writer
[docs]
def write(self, text: str) -> None:
self.buffer += text
if '\n' in self.buffer:
self.flush()
[docs]
def flush(self) -> None:
if self.buffer is None or self.writer is None:
return
try:
self.writer(self.buffer, end='')
except EOFError:
self.writer = None
logger.error('Client ended connection while still streaming output.')
finally:
self.buffer = ''
[docs]
class DaemonService(rpyc.Service):
    """rpyc service the daemon exposes to connecting CLI clients."""

    # This will be populated when the server is started
    manager = None

    def on_connect(self, conn):
        """Keep the connection so the client's own exposed service can be called back."""
        self._conn = conn
        super().on_connect(conn)

    def exposed_version(self):
        """Return the IPC protocol version spoken by this daemon."""
        return IPC_VERSION

    def exposed_handle_cli(self, args):
        """Parse ``args`` and run the resulting command on the daemon's manager.

        Console output and (unless the log level is ``NONE``) logs are routed back to
        the client, except when ``--cron`` is given.

        :param args: The client's command line arguments; copied locally before use.
        """
        args = rpyc.utils.classic.obtain(args)
        logger.verbose('Running command `{}` for client.', ' '.join(args))
        with unittest.mock.patch.dict(os.environ, {'FORCE_COLOR': '1'}):
            parser = get_parser()
            parser_out = self.client_out_stream
            try:
                options = parser.parse_args(args, file=parser_out)
            except SystemExit as e:
                if e.code:
                    # TODO: Not sure how to properly propagate the exit code back to client
                    logger.debug('Parsing cli args caused system exit with status {}.', e.code)
                return
            finally:
                # RemoteStream only sends once it sees a newline; push out any trailing
                # partial line the parser wrote so the client still receives it.
                if parser_out.buffer:
                    parser_out.flush()
        with contextlib.ExitStack() as stack:
            # Don't capture any output when used with --cron
            if not options.cron:
                # Monkeypatch the console function to be the one from the client
                # This means decisions about color formatting, and table sizes can be delayed and
                # decided based on the client terminal capabilities.
                stack.enter_context(
                    unittest.mock.patch('flexget.terminal._patchable_console', self._conn.root.console)
                )
                if options.loglevel != 'NONE':
                    stack.enter_context(capture_logs(self.client_log_sink, level=options.loglevel))
            self.manager.handle_cli(options)

    @property
    def client_out_stream(self):
        """A new ``RemoteStream`` that writes through the client's exposed ``write``."""
        return RemoteStream(self._conn.root.write)

    def client_log_sink(self, message):
        """Forward a log message to the client's exposed ``log_sink``."""
        return self._conn.root.log_sink(message)
class ClientService(rpyc.Service):
    """rpyc service exposed by the CLI client so the daemon can send output back."""

    def on_connect(self, conn):
        """Make sure the daemon's IPC version matches our own.

        :raises ValueError: If the daemon reports a different ``IPC_VERSION``; the
            connection is closed before raising.
        """
        self._conn = conn
        daemon_version = self._conn.root.version()
        if daemon_version != IPC_VERSION:
            self._conn.close()
            raise ValueError('Daemon is different version than client.')
        super().on_connect(conn)

    def exposed_version(self):
        """Return the IPC protocol version spoken by this client."""
        return IPC_VERSION

    def exposed_console(self, text, *args, **kwargs):
        """Print ``text`` (copied locally from the daemon) with the local ``terminal.console``."""
        text = rpyc.classic.obtain(text)
        terminal.console(text, *args, **kwargs)

    def exposed_write(self, text, *args, **kwargs):
        """Print raw ``text``, parsing its ANSI escape sequences into a rich ``Text`` first."""
        text = rpyc.classic.obtain(text)
        text = rich.text.Text.from_ansi(text)
        terminal.console(text, *args, **kwargs)

    def exposed_log_sink(self, message):
        """Re-emit a daemon log message through the local logger.

        The new record is patched with every field of the daemon's record, so details
        such as time and origin are preserved rather than replaced by local values.
        """
        message = rpyc.classic.obtain(message)
        record = message.record
        level, message = record['level'].name, record['message']
        logger.patch(lambda r: r.update(record)).log(level, message)
class IPCServer:
    """Runs the daemon's rpyc server on localhost in a background thread.

    A client must send ``password`` as its first message before the rpyc protocol
    starts (see ``authenticator``).
    """

    def __init__(self, manager, port=None):
        # NOTE(review): ``start`` does not pass this flag to the thread it creates, so the
        # server thread is currently non-daemonic. Confirm whether that is intended.
        self.daemon = True
        self.manager = manager
        self.host = '127.0.0.1'
        # 0 lets the OS pick a free port; ``run`` records the port actually bound.
        self.port = port or 0
        # Shared secret published via ``manager.write_lock`` in ``run``. Uses the ``secrets``
        # CSPRNG: ``random`` output is predictable and unsuitable for authentication.
        alphabet = string.ascii_letters + string.digits
        self.password = ''.join(secrets.choice(alphabet) for _ in range(15))
        self.server = None
        self._thread = None

    def start(self):
        """Start serving in a background thread (a no-op if already started)."""
        if not self._thread:
            self._thread = threading.Thread(name='ipc_server', target=self.run)
            self._thread.start()

    def authenticator(self, sock):
        """Check the password a client sends before rpyc takes over the socket.

        :returns: ``(sock, credentials)`` as rpyc expects from an authenticator.
        :raises AuthenticationError: If the password is wrong or the client hangs up
            before sending one.
        """
        channel = rpyc.Channel(rpyc.SocketStream(sock))
        try:
            received = channel.recv()
        except EOFError as e:
            raise rpyc.utils.authenticators.AuthenticationError(
                'Client disconnected before sending a password.'
            ) from e
        # Compare raw bytes in constant time. This avoids a timing side channel and
        # cannot fail with a decode error on malformed input.
        if not hmac.compare_digest(received, self.password.encode('utf-8')):
            channel.send(AUTH_ERROR)
            raise rpyc.utils.authenticators.AuthenticationError('Invalid password from client.')
        channel.send(AUTH_SUCCESS)
        return sock, self.password

    def run(self):
        """Create the rpyc server, publish its port and password, then serve.

        ``ThreadedServer.start`` blocks, which is why ``start`` runs this in its own thread.
        """
        # Make the rpyc logger a bit quieter when we aren't in debugging.
        rpyc_logger = logging.getLogger('ipc.rpyc')
        if logger.level(self.manager.options.loglevel).no > logger.level('DEBUG').no:
            rpyc_logger.setLevel(logging.WARNING)
        DaemonService.manager = self.manager
        self.server = ThreadedServer(
            DaemonService,
            hostname=self.host,
            port=self.port,
            authenticator=self.authenticator,
            logger=rpyc_logger,
            # Timeout can happen when piping to 'less' and delaying scrolling to bottom. Make it a long timeout.
            protocol_config={'sync_request_timeout': 3600},
        )
        # If we just chose an open port, save the chosen one
        self.port = self.server.listener.getsockname()[1]
        self.manager.write_lock(ipc_info={'port': self.port, 'password': self.password})
        self.server.start()

    def shutdown(self):
        """Close the rpyc server if it was started."""
        if self.server:
            self.server.close()
class IPCClient:
    """Client end of the IPC link; unknown attributes are proxied to the daemon's service."""

    def __init__(self, port, password: str):
        """Connect to the daemon on localhost and authenticate.

        :param port: Port the daemon's IPC server listens on.
        :param password: Password the daemon expects as the first message.
        :raises ValueError: If the daemon rejects the password.
        """
        channel = rpyc.Channel(rpyc.SocketStream.connect('127.0.0.1', port))
        channel.send(password.encode('utf-8'))
        response = channel.recv()
        if response == AUTH_ERROR:
            # Don't leak the socket when authentication fails.
            channel.close()
            # TODO: What to raise here. I guess we create a custom error
            raise ValueError('Invalid password for daemon')
        # ``None`` disables rpyc's sync request timeout (presumably because daemon
        # commands may run arbitrarily long).
        self.conn = rpyc.utils.factory.connect_channel(
            channel, service=ClientService, config={'sync_request_timeout': None}
        )

    def close(self):
        """Close the connection to the daemon."""
        self.conn.close()

    def __getattr__(self, item):
        """Proxy all other calls to the exposed daemon service."""
        # ``conn`` itself being missing (e.g. an instance built without ``__init__``)
        # must raise AttributeError instead of recursing through this method forever.
        if item == 'conn':
            raise AttributeError(item)
        return getattr(self.conn.root, item)