Source code for flexget.components.ftp.sftp_client

from __future__ import annotations

import importlib
import logging
import time
from base64 import b64decode
from collections.abc import Callable
from dataclasses import dataclass
from functools import partial
from pathlib import Path, PurePosixPath
from stat import S_ISLNK
from urllib.parse import quote, urljoin

from loguru import logger

from flexget import plugin
from flexget.entry import Entry
from flexget.task import TaskAbort

# Retry configuration used by SftpClient._connect: the first wait between
# connection attempts is RETRY_INTERVAL_SEC, and the wait grows by
# RETRY_STEP_SEC after every failed attempt.
RETRY_INTERVAL_SEC: int = 15
RETRY_STEP_SEC: int = 5
# Maps an SSH host key type name to the name of the paramiko key class that is
# looked up (via getattr on the paramiko module) in SftpClient._get_cnopts.
HOST_KEY_TYPES: dict = {
    'ssh-rsa': 'RSAKey',
    'ssh-ed25519': 'Ed25519Key',
}

# paramiko and pysftp are optional at import time; when either import fails,
# pysftp is set to None and SftpClient.__init__ raises a DependencyError.
try:
    import paramiko
    import pysftp

    # Only let paramiko's own logger emit records at ERROR level or above.
    logging.getLogger('paramiko').setLevel(logging.ERROR)
except ImportError:
    pysftp = None

# Signature of the per-node callbacks passed to pysftp's walktree: they receive
# the node path and return nothing.
NodeHandler = Callable[[str], None]

# Rebind the module-level loguru logger with this component's name.
logger = logger.bind(name='sftp_client')


[docs] def _set_authentication_patch(self, password, private_key, private_key_pass): """Patch pysftp.Connection._set_authentication to support additional key types.""" if password is None: # Use Private Key. if not private_key: # Try to use default key. if Path('~/.ssh/id_rsa').expanduser().exists(): private_key = '~/.ssh/id_rsa' elif Path('~/.ssh/id_dsa').expanduser().exists(): private_key = '~/.ssh/id_dsa' else: raise pysftp.exceptions.CredentialException('No password or key specified.') if isinstance(private_key, (paramiko.AgentKey, paramiko.RSAKey)): # use the paramiko agent or rsa key self._tconnect['pkey'] = private_key else: # isn't a paramiko AgentKey or RSAKey, try to build a # key from what we assume is a path to a key private_key_file = Path(private_key).expanduser() for key in [paramiko.RSAKey, paramiko.DSSKey, paramiko.Ed25519Key, paramiko.ECDSAKey]: try: # try all the keys self._tconnect['pkey'] = key.from_private_key_file( private_key_file, private_key_pass ) except paramiko.SSHException: # if it fails, try dss pass else: return raise paramiko.SSHException(f'Unknown key type: {private_key}')
@dataclass
class HostKey:
    """Host key to trust for an SFTP server that is not listed in known_hosts."""

    # Key type name, e.g. 'ssh-rsa'; used as a lookup key into HOST_KEY_TYPES.
    key_type: str
    # Public key material; it is base64-decoded before being handed to paramiko.
    public_key: str
class SftpClient:
    """Client for listing, downloading and uploading files over SFTP via pysftp.

    The connection is opened by the constructor and stays open until
    :meth:`close` is called.

    :param host: server host name
    :param port: server port
    :param username: login user name
    :param password: login password, or ``None`` to authenticate with a private key
    :param private_key: path to a private key file
    :param private_key_pass: passphrase of the private key
    :param host_key: host key to trust when the server is not in known_hosts
    :param connection_tries: number of connection attempts before giving up
    :raises plugin.DependencyError: if pysftp (or paramiko) could not be imported
    :raises TaskAbort: if the connection could not be established
    """

    def __init__(
        self,
        host: str,
        port: int,
        username: str,
        password: str | None = None,
        private_key: str | None = None,
        private_key_pass: str | None = None,
        host_key: HostKey | None = None,
        connection_tries: int = 3,
    ):
        if not pysftp:
            raise plugin.DependencyError(
                issued_by='sftp_client',
                missing='pysftp',
                message='sftp client requires the pysftp Python module.',
            )

        self.host: str = host
        self.port: int = port
        self.username: str = username
        self.password: str | None = password
        self.private_key: str | None = private_key
        self.private_key_pass: str | None = private_key_pass
        self.host_key: HostKey | None = host_key

        # NOTE: when a password is set, the prefix embeds it in clear text.
        self.prefix: str = self._get_prefix()
        self._sftp: pysftp.Connection = self._connect(connection_tries)
        self._handler_builder: HandlerBuilder = HandlerBuilder(
            self._sftp, self.prefix, self.private_key, self.private_key_pass, self.host_key
        )

    def list_directories(
        self,
        directories: list[str],
        recursive: bool,
        get_size: bool,
        files_only: bool,
        dirs_only: bool,
    ) -> list[Entry]:
        """Build a list of entries from a provided list of directories on an SFTP server.

        Directories that cannot be opened are logged and skipped.

        :param directories: list of directories to generate entries for
        :param recursive: boolean indicating whether to list recursively
        :param get_size: boolean indicating whether to compute size for each node
            (potentially slow for directories)
        :param files_only: boolean indicating whether to exclude directories
        :param dirs_only: boolean indicating whether to exclude files
        :return: a list of entries describing the contents of the provided directories
        """
        entries: list[Entry] = []

        dir_handler: NodeHandler = self._handler_builder.get_dir_handler(
            get_size, files_only, entries
        )
        file_handler: NodeHandler = self._handler_builder.get_file_handler(
            get_size, dirs_only, entries
        )
        unknown_handler: NodeHandler = self._handler_builder.get_unknown_handler()

        for directory in directories:
            try:
                # Only the root path is normalized; nodes discovered below it are
                # not, so symlinks appear raw in the entry paths rather than being
                # resolved to their targets.
                self._sftp.walktree(
                    self._sftp.normalize(directory),
                    file_handler,
                    dir_handler,
                    unknown_handler,
                    recursive,
                )
            except OSError as e:
                logger.warning('Failed to open {} ({})', directory, str(e))

        return entries

    def download(self, source: str, to: str, recursive: bool, delete_origin: bool) -> None:
        """Download the file or directory specified in "source" into the directory "to".

        :param source: path of the resource to download
        :param to: path of the directory to download to
        :param recursive: indicates whether to download the contents of "source" recursively
        :param delete_origin: indicates whether to delete the source resource upon download;
            if the source is a symlink, only the symlink is removed rather than its target
        :raises SftpError: if the source does not exist or the download fails
        """
        if not self.path_exists(source):
            raise SftpError(f'Remote path does not exist: {source}')

        parsed_path: PurePosixPath = PurePosixPath(source)
        is_symlink: bool = self.is_link(source)
        # Files reached through a symlink are never deleted individually; the
        # link itself is removed once the download has finished.
        delete_files: bool = delete_origin and not is_symlink

        if self.is_file(source):
            try:
                self._sftp.cwd(str(parsed_path.parent))
                self._download_file(to, delete_files, parsed_path.name)
            except Exception as e:
                raise SftpError(f'Failed to download file {source} ({e!s})') from e

            if delete_origin and is_symlink:
                self.remove_file(source)
        elif self.is_dir(source):
            dir_handler: NodeHandler = self._handler_builder.get_null_handler()
            unknown_handler: NodeHandler = self._handler_builder.get_unknown_handler()
            handle_file: NodeHandler = partial(self._download_file, to, delete_files)

            try:
                # Walk relative to the parent so local paths start at the
                # directory's own name instead of the full remote path.
                self._sftp.cwd(str(parsed_path.parent))
                self._sftp.walktree(
                    parsed_path.name, handle_file, dir_handler, unknown_handler, recursive
                )
            except Exception as e:
                raise SftpError(f'Failed to download directory {source} ({e!s})') from e

            if delete_origin:
                if is_symlink:
                    self.remove_file(source)
                else:
                    self.remove_dir(source)
        else:
            logger.warning('Skipping unknown file: {}', source)

    def upload(self, source: Path, to: str) -> None:
        """Upload a file to an SFTP server; directories are skipped.

        :param source: local file to upload
        :param to: remote destination directory
        :raises SftpError: if the upload fails
        """
        if source.is_dir():
            logger.verbose('Skipping directory {}', source)
        else:
            self._upload_file(source, to)

    def remove_dir(self, path: str) -> None:
        """Remove a remote directory if it exists and is empty.

        Failures are logged rather than raised.

        :param path: directory to remove
        """
        if self._sftp.exists(path) and not self._sftp.listdir(path):
            logger.debug('Attempting to delete directory {}', path)
            try:
                self._sftp.rmdir(path)
            except Exception as e:
                logger.error('Failed to delete directory {} ({})', path, str(e))

    def remove_file(self, path: str) -> None:
        """Remove a remote file.

        Failures are logged rather than raised.

        :param path: file to remove
        """
        logger.debug('Deleting remote file {}', path)
        try:
            self._sftp.remove(path)
        except Exception as e:
            logger.error('Failed to delete file {} ({})', path, str(e))

    def is_file(self, path: str) -> bool:
        """Check if the node at a given path is a file.

        :param path: path to check
        :return: boolean indicating if the path is a file
        """
        return self._sftp.isfile(path)

    def is_dir(self, path: str) -> bool:
        """Check if the node at a given path is a directory.

        :param path: path to check
        :return: boolean indicating if the path is a directory
        """
        return self._sftp.isdir(path)

    def is_link(self, path: str) -> bool:
        """Check if the node at a given path is a symlink.

        Uses ``lstat`` so that the link itself is inspected.

        :param path: path to check
        :return: boolean indicating if the path is a symlink
        """
        return S_ISLNK(self._sftp.lstat(path).st_mode)

    def path_exists(self, path: str) -> bool:
        """Check if a path exists.

        :param path: path to check
        :return: boolean indicating if the path exists
        """
        return self._sftp.lexists(path)

    def make_dirs(self, path: str) -> None:
        """Create a remote directory, including missing parents, if it does not exist.

        :param path: path to build
        :raises SftpError: if the directory could not be created
        """
        if not self.path_exists(path):
            try:
                self._sftp.makedirs(path)
            except Exception as e:
                raise SftpError(f'Failed to create remote directory {path} ({e!s})') from e

    def close(self) -> None:
        """Close the sftp connection."""
        self._sftp.close()

    def set_socket_timeout(self, socket_timeout_sec: float | None) -> None:
        """Set the SFTP client socket timeout.

        :param socket_timeout_sec: socket timeout in seconds
        """
        self._sftp.timeout = socket_timeout_sec

    def _connect(self, connection_tries: int) -> pysftp.Connection:
        """Open the connection, retrying with a growing delay between attempts.

        :param connection_tries: number of attempts; values below 1 behave like 1
        :return: the established connection
        :raises TaskAbort: when every attempt has failed
        """
        tries_left: int = connection_tries
        retry_interval: int = RETRY_INTERVAL_SEC

        logger.debug('Connecting to {}', self.host)

        # Install the patched authentication routine so that non-RSA key files
        # are accepted.
        pysftp.Connection._set_authentication = _set_authentication_patch

        while True:
            try:
                sftp = pysftp.Connection(
                    host=self.host,
                    username=self.username,
                    private_key=self.private_key,
                    password=self.password,
                    port=self.port,
                    private_key_pass=self.private_key_pass,
                    cnopts=self._get_cnopts(),
                )
            except Exception as e:
                tries_left -= 1
                logger.debug('Caught exception: {}', e)
                # "<= 0" rather than "== 0" so a non-positive connection_tries
                # cannot turn this into an endless retry loop.
                if tries_left <= 0:
                    raise TaskAbort(f'Failed to connect to {self.host}') from e
                logger.warning(
                    'Failed to connect to {}; waiting {} seconds before retrying.',
                    self.host,
                    retry_interval,
                )
                time.sleep(retry_interval)
                retry_interval += RETRY_STEP_SEC
            else:
                logger.verbose('Connected to {}', self.host)
                return sftp

    def _get_cnopts(self) -> pysftp.CnOpts | None:
        """Build connection options carrying the configured host key.

        :return: ``None`` when no host key is configured, otherwise a ``CnOpts``
            whose host keys include the configured key for this host
        """
        if not self.host_key:
            return None

        KeyClass = getattr(  # noqa: N806 It's a class
            importlib.import_module('paramiko'), HOST_KEY_TYPES[self.host_key.key_type]
        )
        key = KeyClass(data=b64decode(self.host_key.public_key))
        cnopts = pysftp.CnOpts()
        cnopts.hostkeys.add(self.host, self.host_key.key_type, key)
        return cnopts

    def _upload_file(self, source: Path, to: str) -> None:
        """Upload a single local file into the remote directory ``to``.

        The remote directory is created when missing. A source file that has
        disappeared is logged and skipped.

        :param source: local file to upload
        :param to: remote destination directory
        :raises SftpError: if the destination cannot be created, is not a
            directory, or the transfer fails
        """
        if not source.exists():
            logger.warning('File no longer exists: {}', source)
            return

        destination = self._get_upload_path(source, to)
        destination_url: str = urljoin(self.prefix, destination)

        if not self.path_exists(to):
            try:
                self.make_dirs(to)
            except Exception as e:
                raise SftpError(f'Failed to create remote directory {to} ({e!s})') from e

        if not self.is_dir(to):
            raise SftpError(f'Not a directory: {to}')

        try:
            self._put_file(source, destination)
            logger.verbose('Successfully uploaded {} to {}', source, destination_url)
        except OSError as e:
            raise SftpError(f'Remote directory does not exist: {to}') from e
        except Exception as e:
            raise SftpError(f'Failed to upload {source} ({e!s})') from e

    def _download_file(self, destination: str, delete_origin: bool, source: str) -> None:
        """Download one remote file, skipping it when the local file already exists.

        ``source`` comes last so the method can be partially applied and used as
        a walktree file handler.

        :param destination: local directory to download into
        :param delete_origin: whether to remove the remote file after a successful download
        :param source: remote path of the file, also used as the path below ``destination``
        """
        destination_path: str = self._get_download_path(source, destination)
        local_file = Path(destination_path)

        if local_file.exists():
            logger.verbose(
                'Skipping {} because destination file {} already exists.', source, destination_path
            )
            return

        local_file.parent.mkdir(parents=True, exist_ok=True)

        logger.verbose('Downloading file {} to {}', source, destination)

        try:
            self._sftp.get(source, destination_path)
        except Exception as e:
            logger.error('Failed to download {} ({})', source, e)
            # Do not leave a truncated file behind: it would make the next run
            # skip this download.
            if local_file.exists():
                logger.debug('Removing partially downloaded file {}', destination_path)
                local_file.unlink()
            raise

        if delete_origin:
            self.remove_file(source)

    def _put_file(self, source: Path, destination: str) -> None:
        """Transfer a local file to the given remote path."""
        return self._sftp.put(str(source), destination)

    def _get_prefix(self) -> str:
        """Generate the SFTP URL prefix, ``sftp://[user[:password]@]host[:port]/``.

        The port is omitted when it is unset or equal to 22.
        """
        if self.username and self.password:
            login_string = f'{self.username}:{self.password}@'
        elif self.username:
            login_string = f'{self.username}@'
        else:
            login_string = ''

        port_string = f':{self.port}' if self.port and self.port != 22 else ''

        return f'sftp://{login_string}{self.host}{port_string}/'

    @staticmethod
    def _get_download_path(path: str, destination: str) -> str:
        """Return the local path for remote ``path`` downloaded below ``destination``."""
        return str(PurePosixPath(destination) / path)

    @staticmethod
    def _get_upload_path(source: Path, to: str) -> str:
        """Return the remote path for local file ``source`` uploaded into ``to``."""
        basename: str = source.name
        return str(PurePosixPath(to, basename))
class SftpError(Exception):
    """Raised when an SFTP operation (download, upload, directory creation) fails."""
class HandlerBuilder:
    """Factory for the node callbacks consumed by ``pysftp.Connection.walktree``.

    :param sftp: connection the handlers query for node sizes
    :param url_prefix: SFTP URL prefix used to build entry URLs
    :param private_key: private key path copied onto every entry
    :param private_key_pass: private key passphrase copied onto every entry
    :param host_key: host key copied onto every entry, if any
    """

    def __init__(
        self,
        sftp: pysftp.Connection,
        url_prefix: str,
        private_key: str | None,
        private_key_pass: str | None,
        host_key: HostKey | None,
    ):
        self._sftp = sftp
        self._prefix = url_prefix
        self._private_key = private_key
        self._private_key_pass = private_key_pass
        self._host_key = host_key

    def _bind_entry_handler(
        self, handler: Callable[..., None], get_size: bool, skip: bool, entry_accumulator: list
    ) -> NodeHandler:
        """Pre-fill every argument of an entry-producing handler except the node path."""
        bound_args = (
            self._sftp,
            self._prefix,
            get_size,
            skip,
            self._private_key,
            self._private_key_pass,
            self._host_key,
            entry_accumulator,
        )
        return partial(handler, *bound_args)

    def get_file_handler(
        self, get_size: bool, dirs_only: bool, entry_accumulator: list
    ) -> NodeHandler:
        """Build a file node handler suitable for use with pysftp.Connection.walktree.

        :param get_size: boolean indicating whether to compute the size of each file
        :param dirs_only: boolean indicating whether to skip files
        :param entry_accumulator: list to add entries to
        """
        return self._bind_entry_handler(
            Handlers.handle_file, get_size, dirs_only, entry_accumulator
        )

    def get_dir_handler(
        self, get_size: bool, files_only: bool, entry_accumulator: list
    ) -> NodeHandler:
        """Build a directory node handler suitable for use with pysftp.Connection.walktree.

        :param get_size: boolean indicating whether to compute the size of each directory
        :param files_only: boolean indicating whether to skip directories
        :param entry_accumulator: list to add entries to
        """
        return self._bind_entry_handler(
            Handlers.handle_directory, get_size, files_only, entry_accumulator
        )

    def get_unknown_handler(self) -> NodeHandler:
        """Build an unknown node handler suitable for use with pysftp.Connection.walktree."""
        unknown_handler = partial(Handlers.handle_unknown)
        return unknown_handler

    def get_null_handler(self) -> NodeHandler:
        """Build a noop node handler suitable for use with pysftp.Connection.walktree."""
        null_handler = partial(Handlers.null_node_handler)
        return null_handler
class Handlers:
    """Node callbacks for ``pysftp.Connection.walktree`` plus their size helpers."""

    @classmethod
    def handle_file(
        cls,
        sftp: pysftp.Connection,
        prefix: str,
        get_size: bool,
        dirs_only: bool,
        private_key: str | None,
        private_key_pass: str | None,
        host_key: HostKey | None,
        entry_accumulator: list[Entry],
        path: str,
    ) -> None:
        """File node handler. Adds a file entry to entry_accumulator.

        :param sftp: A pysftp.Connection object
        :param prefix: SFTP URL prefix
        :param get_size: boolean indicating whether to compute the size of each file
        :param dirs_only: boolean indicating whether to skip files
        :param private_key: private key path
        :param private_key_pass: private key password
        :param host_key: Host key for the remote server if not in known_hosts
        :param entry_accumulator: a list in which to store entries
        :param path: path to handle
        """
        if not dirs_only:
            measure = partial(cls._file_size, sftp)
            entry_accumulator.append(
                cls._get_entry(
                    sftp, prefix, measure, get_size, path, private_key, private_key_pass, host_key
                )
            )

    @classmethod
    def handle_directory(
        cls,
        sftp: pysftp.Connection,
        prefix: str,
        get_size: bool,
        files_only: bool,
        private_key: str | None,
        private_key_pass: str | None,
        host_key: HostKey | None,
        entry_accumulator: list[Entry],
        path: str,
    ) -> None:
        """Directory node handler. Adds a directory entry to entry_accumulator.

        :param sftp: A pysftp.Connection object
        :param prefix: SFTP URL prefix
        :param get_size: boolean indicating whether to compute the size of each directory
        :param files_only: boolean indicating whether to skip directories
        :param private_key: private key path
        :param private_key_pass: private key password
        :param host_key: Host key for the remote server if not in known_hosts
        :param entry_accumulator: a list in which to store entries
        :param path: path to handle
        """
        if not files_only:
            measure: Callable[[str], int] = partial(cls._dir_size, sftp)
            entry_accumulator.append(
                cls._get_entry(
                    sftp, prefix, measure, get_size, path, private_key, private_key_pass, host_key
                )
            )

    @staticmethod
    def handle_unknown(path: str) -> None:
        """Handle unknown nodes; log a warning.

        :param path: path to handle
        """
        logger.warning('Skipping unknown file: {}', path)

    @staticmethod
    def null_node_handler(path: str) -> None:
        """Handle a node by doing nothing except logging at debug level.

        :param path: path to handle
        """
        logger.debug('null handler called for {}', path)

    @staticmethod
    def _get_entry(
        sftp: pysftp.Connection,
        prefix: str,
        size_handler: Callable[[str], int],
        get_size,
        path: str,
        private_key: str | None,
        private_key_pass: str | None,
        host_key: HostKey | None,
    ) -> Entry:
        """Create the entry describing the node at ``path``.

        The title is the last path component and the URL is the quoted path
        joined onto ``prefix``. When ``get_size`` is set, ``content_size`` holds
        the result of ``size_handler``, or -1 if that call fails.
        """
        node_url = urljoin(prefix, quote(path))
        node_name = PurePosixPath(path).name
        entry = Entry(node_name, node_url)

        if get_size:
            try:
                entry['content_size'] = size_handler(path)
            except Exception as e:
                logger.warning('Failed to get size for {} ({})', path, e)
                entry['content_size'] = -1

        # Credentials travel with the entry.
        entry['private_key'] = private_key
        entry['private_key_pass'] = private_key_pass

        if host_key:
            entry['host_key'] = {
                'key_type': host_key.key_type,
                'public_key': host_key.public_key,
            }

        return entry

    @classmethod
    def _dir_size(cls, sftp: pysftp.Connection, path: str) -> int:
        """Sum the sizes of every node found by recursively walking ``path``."""
        collected: list[int] = []
        collect = partial(cls._accumulate_file_size, sftp, collected)
        # The same collector handles files, directories and unknown nodes.
        sftp.walktree(path, collect, collect, collect, True)
        return sum(collected)

    @classmethod
    def _accumulate_file_size(
        cls, sftp: pysftp.Connection, size_accumulator: list[int], path: str
    ) -> None:
        """Append the size of the node at ``path`` to ``size_accumulator``."""
        node_size = cls._file_size(sftp, path)
        size_accumulator.append(node_size)

    @staticmethod
    def _file_size(sftp: pysftp.Connection, path: str) -> int:
        """Get the size of a file node as reported by ``lstat``."""
        node_stat = sftp.lstat(path)
        return node_stat.st_size