diff --git a/devlib/connection.py b/devlib/connection.py index ef0cb27..aacb48a 100644 --- a/devlib/connection.py +++ b/devlib/connection.py @@ -14,12 +14,11 @@ # from abc import ABC, abstractmethod -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from datetime import datetime from functools import partial from weakref import WeakSet from shlex import quote -from time import monotonic import os import signal import socket @@ -61,13 +60,27 @@ class ConnectionBase(InitCheckpoint): """ Base class for all connections. """ - def __init__(self): + def __init__( + self, + poll_transfers=False, + start_transfer_poll_delay=30, + total_transfer_timeout=3600, + transfer_poll_period=30, + ): self._current_bg_cmds = set() self._closed = False self._close_lock = threading.Lock() self.busybox = None self.logger = logging.getLogger('Connection') + self.transfer_manager = TransferManager( + self, + start_transfer_poll_delay=start_transfer_poll_delay, + total_transfer_timeout=total_transfer_timeout, + transfer_poll_period=transfer_poll_period, + ) if poll_transfers else NoopTransferManager() + + def cancel_running_command(self): bg_cmds = set(self._current_bg_cmds) for bg_cmd in bg_cmds: @@ -405,13 +418,13 @@ class ParamikoBackgroundCommand(BackgroundCommand): b''.join(out[stderr]) ) - start = monotonic() + start = time.monotonic() while ret is None: # Even if ret is not None anymore, we need to drain the streams ret = self.poll() - if timeout is not None and ret is None and monotonic() - start >= timeout: + if timeout is not None and ret is None and time.monotonic() - start >= timeout: self.cancel() _stdout, _stderr = create_out() raise subprocess.TimeoutExpired(self.cmd, timeout, _stdout, _stderr) @@ -563,9 +576,89 @@ class AdbBackgroundCommand(BackgroundCommand): return self -class TransferManagerBase(ABC): +class TransferManager: + def __init__(self, conn, transfer_poll_period=30, start_transfer_poll_delay=30, total_transfer_timeout=3600): + self.conn = conn + self.transfer_poll_period = transfer_poll_period + self.total_transfer_timeout = total_transfer_timeout + self.start_transfer_poll_delay = start_transfer_poll_delay - def _pull_dest_size(self, dest): + self.logger = logging.getLogger('FileTransfer') + + @contextmanager + def manage(self, sources, dest, direction, handle): + excep = None + stop_thread = threading.Event() + + def monitor(): + nonlocal excep + + def cancel(reason): + self.logger.warning( + f'Cancelling file transfer {sources} -> {dest} due to: {reason}' + ) + handle.cancel() + + start_t = time.monotonic() + stop_thread.wait(self.start_transfer_poll_delay) + while not stop_thread.wait(self.transfer_poll_period): + if not handle.isactive(): + cancel(reason='transfer inactive') + elif time.monotonic() - start_t > self.total_transfer_timeout: + cancel(reason='transfer timed out') + excep = TimeoutError(f'{direction}: {sources} -> {dest}') + + m_thread = threading.Thread(target=monitor, daemon=True) + try: + m_thread.start() + yield self + finally: + stop_thread.set() + m_thread.join() + if excep is not None: + raise excep + + +class NoopTransferManager: + def manage(self, *args, **kwargs): + return nullcontext(self) + + +class TransferHandleBase(ABC): + def __init__(self, manager): + self.manager = manager + + @property + def logger(self): + return self.manager.logger + + @abstractmethod + def isactive(self): + pass + + @abstractmethod + def cancel(self): + pass + + +class PopenTransferHandle(TransferHandleBase): + def __init__(self, bg_cmd, dest, direction, *args, **kwargs): + super().__init__(*args, **kwargs) + + if direction == 'push': + sample_size = self._push_dest_size + elif direction == 'pull': + sample_size = self._pull_dest_size + else: + raise ValueError(f'Unknown direction: {direction}') + + self.sample_size = lambda: sample_size(dest) + + self.bg_cmd = bg_cmd + self.last_sample = 0 + + @staticmethod + def _pull_dest_size(dest): if os.path.isdir(dest): return sum( os.stat(os.path.join(dirpath, f)).st_size @@ -576,148 +669,58 @@ class TransferManagerBase(ABC): return os.stat(dest).st_size def _push_dest_size(self, dest): - cmd = '{} du -s {}'.format(quote(self.conn.busybox), quote(dest)) - out = self.conn.execute(cmd) - try: - return int(out.split()[0]) - except ValueError: - return 0 + conn = self.manager.conn + cmd = '{} du -s -- {}'.format(quote(conn.busybox), quote(dest)) + out = conn.execute(cmd) + return int(out.split()[0]) - def __init__(self, conn, poll_period, start_transfer_poll_delay, total_timeout): - self.conn = conn - self.poll_period = poll_period - self.total_timeout = total_timeout - self.start_transfer_poll_delay = start_transfer_poll_delay - - self.logger = logging.getLogger('FileTransfer') - self.managing = threading.Event() - self.transfer_started = threading.Event() - self.transfer_completed = threading.Event() - self.transfer_aborted = threading.Event() - - self.monitor_thread = None - self.sources = None - self.dest = None - self.direction = None - - @abstractmethod - def _cancel(self): - pass - - def cancel(self, reason=None): - msg = 'Cancelling file transfer {} -> {}'.format(self.sources, self.dest) - if reason is not None: - msg += ' due to \'{}\''.format(reason) - self.logger.warning(msg) - self.transfer_aborted.set() - self._cancel() - - @abstractmethod - def isactive(self): - pass - - @contextmanager - def manage(self, sources, dest, direction): - try: - self.sources, self.dest, self.direction = sources, dest, direction - m_thread = threading.Thread(target=self._monitor) - - self.transfer_completed.clear() - self.transfer_aborted.clear() - self.transfer_started.set() - - m_thread.start() - yield self - except BaseException: - self.cancel(reason='exception during transfer') - raise - finally: - self.transfer_completed.set() - self.transfer_started.set() - m_thread.join() - self.transfer_started.clear() - self.transfer_completed.clear() - self.transfer_aborted.clear() - - def _monitor(self): - start_t = monotonic() - self.transfer_completed.wait(self.start_transfer_poll_delay) - while not self.transfer_completed.wait(self.poll_period): - if not self.isactive(): - self.cancel(reason='transfer inactive') - elif monotonic() - start_t > self.total_timeout: - self.cancel(reason='transfer timed out') - - -class PopenTransferManager(TransferManagerBase): - - def __init__(self, conn, poll_period=30, start_transfer_poll_delay=30, total_timeout=3600): - super().__init__(conn, poll_period, start_transfer_poll_delay, total_timeout) - self.transfer = None - self.last_sample = None - - def _cancel(self): - if self.transfer: - self.transfer.cancel() - self.transfer = None - self.last_sample = None + def cancel(self): + self.bg_cmd.cancel() def isactive(self): - size_fn = self._push_dest_size if self.direction == 'push' else self._pull_dest_size - curr_size = size_fn(self.dest) - self.logger.debug('Polled file transfer, destination size {}'.format(curr_size)) - active = True if self.last_sample is None else curr_size > self.last_sample - self.last_sample = curr_size - return active - - def set_transfer_and_wait(self, popen_bg_cmd): - self.transfer = popen_bg_cmd - self.last_sample = None - ret = self.transfer.wait() - - if ret and not self.transfer_aborted.is_set(): - raise subprocess.CalledProcessError(ret, self.transfer.popen.args) - elif self.transfer_aborted.is_set(): - raise TimeoutError(self.transfer.popen.args) + try: + curr_size = self.sample_size() + except Exception as e: + self.logger.debug(f'File size polling failed: {e}') + return True + else: + self.logger.debug(f'Polled file transfer, destination size: {curr_size}') + if curr_size: + active = curr_size > self.last_sample + self.last_sample = curr_size + return active + # If the file is empty it will never grow in size, so we assume + # everything is going well. + else: + return True -class SSHTransferManager(TransferManagerBase): +class SSHTransferHandle(TransferHandleBase): + + def __init__(self, handle, *args, **kwargs): + super().__init__(*args, **kwargs) + + # SFTPClient or SSHClient + self.handle = handle - def __init__(self, conn, poll_period=30, start_transfer_poll_delay=30, total_timeout=3600): - super().__init__(conn, poll_period, start_transfer_poll_delay, total_timeout) - self.transferer = None self.progressed = False - self.transferred = None - self.to_transfer = None + self.transferred = 0 + self.to_transfer = 0 - def _cancel(self): - self.transferer.close() + def cancel(self): + self.handle.close() def isactive(self): progressed = self.progressed - self.progressed = False - msg = 'Polled transfer: {}% [{}B/{}B]' - pc = format((self.transferred / self.to_transfer) * 100, '.2f') - self.logger.debug(msg.format(pc, self.transferred, self.to_transfer)) + if progressed: + self.progressed = False + pc = (self.transferred / self.to_transfer) * 100 + self.logger.debug( + f'Polled transfer: {pc:.2f}% [{self.transferred}B/{self.to_transfer}B]' + ) return progressed - @contextmanager - def manage(self, sources, dest, direction, transferer): - with super().manage(sources, dest, direction): - try: - self.progressed = False - self.transferer = transferer # SFTPClient or SCPClient - yield self - except socket.error as e: - if self.transfer_aborted.is_set(): - self.transfer_aborted.clear() - method = 'SCP' if self.conn.use_scp else 'SFTP' - raise TimeoutError('{} {}: {} -> {}'.format(method, self.direction, sources, self.dest)) - else: - raise e - - def progress_cb(self, to_transfer, transferred): - if self.transfer_started.is_set(): - self.progressed = True - self.transferred = transferred - self.to_transfer = to_transfer + def progress_cb(self, transferred, to_transfer): + self.progressed = True + self.transferred = transferred + self.to_transfer = to_transfer diff --git a/devlib/target.py b/devlib/target.py index 4fb385f..4eeeb5c 100644 --- a/devlib/target.py +++ b/devlib/target.py @@ -2931,7 +2931,7 @@ class ChromeOsTarget(LinuxTarget): ssh_conn_params = ['host', 'username', 'password', 'keyfile', 'port', 'timeout', 'sudo_cmd', 'strict_host_check', 'use_scp', - 'total_timeout', 'poll_transfers', + 'total_transfer_timeout', 'poll_transfers', 'start_transfer_poll_delay'] self.ssh_connection_settings = {} for setting in ssh_conn_params: diff --git a/devlib/utils/android.py b/devlib/utils/android.py index 1cae582..cc6c7e4 100755 --- a/devlib/utils/android.py +++ b/devlib/utils/android.py @@ -39,7 +39,7 @@ from shlex import quote from devlib.exception import TargetTransientError, TargetStableError, HostError, TargetTransientCalledProcessError, TargetStableCalledProcessError, AdbRootError from devlib.utils.misc import check_output, which, ABI_MAP, redirect_streams, get_subprocess -from devlib.connection import ConnectionBase, AdbBackgroundCommand, PopenBackgroundCommand, PopenTransferManager +from devlib.connection import ConnectionBase, AdbBackgroundCommand, PopenBackgroundCommand, PopenTransferHandle logger = logging.getLogger('android') @@ -278,26 +278,32 @@ class AdbConnection(ConnectionBase): self._connected_as_root[self.device] = state # pylint: disable=unused-argument - def __init__(self, device=None, timeout=None, platform=None, adb_server=None, - adb_as_root=False, connection_attempts=MAX_ATTEMPTS, - poll_transfers=False, - start_transfer_poll_delay=30, - total_transfer_timeout=3600, - transfer_poll_period=30,): - super().__init__() + def __init__( + self, + device=None, + timeout=None, + platform=None, + adb_server=None, + adb_as_root=False, + connection_attempts=MAX_ATTEMPTS, + + poll_transfers=False, + start_transfer_poll_delay=30, + total_transfer_timeout=3600, + transfer_poll_period=30, + ): + super().__init__( + poll_transfers=poll_transfers, + start_transfer_poll_delay=start_transfer_poll_delay, + total_transfer_timeout=total_transfer_timeout, + transfer_poll_period=transfer_poll_period, + ) self.timeout = timeout if timeout is not None else self.default_timeout if device is None: device = adb_get_device(timeout=timeout, adb_server=adb_server) self.device = device self.adb_server = adb_server self.adb_as_root = adb_as_root - self.poll_transfers = poll_transfers - if poll_transfers: - transfer_opts = {'start_transfer_poll_delay': start_transfer_poll_delay, - 'total_timeout': total_transfer_timeout, - 'poll_period': transfer_poll_period, - } - self.transfer_mgr = PopenTransferManager(self, **transfer_opts) if poll_transfers else None lock, nr_active = AdbConnection.active_connections with lock: nr_active[self.device] += 1 @@ -330,17 +336,24 @@ class AdbConnection(ConnectionBase): paths = ' '.join(map(do_quote, paths)) command = "{} {}".format(action, paths) - if timeout or not self.poll_transfers: + if timeout: adb_command(self.device, command, timeout=timeout, adb_server=self.adb_server) else: - with self.transfer_mgr.manage(sources, dest, action): - bg_cmd = adb_command_background( - device=self.device, - conn=self, - command=command, - adb_server=self.adb_server - ) - self.transfer_mgr.set_transfer_and_wait(bg_cmd) + bg_cmd = adb_command_background( + device=self.device, + conn=self, + command=command, + adb_server=self.adb_server + ) + + handle = PopenTransferHandle( + manager=self.transfer_manager, + bg_cmd=bg_cmd, + dest=dest, + direction=action + ) + with bg_cmd, self.transfer_manager.manage(sources, dest, action, handle): + bg_cmd.communicate() # pylint: disable=unused-argument def execute(self, command, timeout=None, check_exit_code=False, diff --git a/devlib/utils/ssh.py b/devlib/utils/ssh.py index 07a365f..1f517a3 100644 --- a/devlib/utils/ssh.py +++ b/devlib/utils/ssh.py @@ -59,7 +59,7 @@ from devlib.utils.misc import (which, strip_bash_colors, check_output, sanitize_cmd_template, memoized, redirect_streams) from devlib.utils.types import boolean from devlib.connection import (ConnectionBase, ParamikoBackgroundCommand, PopenBackgroundCommand, - SSHTransferManager) + SSHTransferHandle) DEFAULT_SSH_SUDO_COMMAND = "sudo -k -p ' ' -S -- sh -c {}" @@ -295,8 +295,18 @@ class SshConnectionBase(ConnectionBase): platform=None, sudo_cmd=DEFAULT_SSH_SUDO_COMMAND, strict_host_check=True, + + poll_transfers=False, + start_transfer_poll_delay=30, + total_transfer_timeout=3600, + transfer_poll_period=30, ): - super().__init__() + super().__init__( + poll_transfers=poll_transfers, + start_transfer_poll_delay=start_transfer_poll_delay, + total_transfer_timeout=total_transfer_timeout, + transfer_poll_period=transfer_poll_period, + ) self._connected_as_root = None self.host = host self.username = username @@ -337,24 +347,21 @@ class SshConnection(SshConnectionBase): platform=platform, sudo_cmd=sudo_cmd, strict_host_check=strict_host_check, + + poll_transfers=poll_transfers, + start_transfer_poll_delay=start_transfer_poll_delay, + total_transfer_timeout=total_transfer_timeout, + transfer_poll_period=transfer_poll_period, ) self.timeout = timeout if timeout is not None else self.default_timeout # Allow using scp for file transfer if sftp is not supported self.use_scp = use_scp - self.poll_transfers=poll_transfers - if poll_transfers: - transfer_opts = {'start_transfer_poll_delay': start_transfer_poll_delay, - 'total_timeout': total_transfer_timeout, - 'poll_period': transfer_poll_period, - } - if self.use_scp: logger.debug('Using SCP for file transfer') else: logger.debug('Using SFTP for file transfer') - self.transfer_mgr = SSHTransferManager(self, **transfer_opts) if poll_transfers else None self.client = self._make_client() atexit.register(self.close) @@ -426,12 +433,12 @@ class SshConnection(SshConnectionBase): return sftp @functools.lru_cache() - def _get_scp(self, timeout): - cb = lambda _, to_transfer, transferred: self.transfer_mgr.progress_cb(to_transfer, transferred) + def _get_scp(self, timeout, callback=lambda *_: None): + cb = lambda _, to_transfer, transferred: callback(to_transfer, transferred) return SCPClient(self.client.get_transport(), socket_timeout=timeout, progress=cb) - def _push_file(self, sftp, src, dst): - sftp.put(src, dst, callback=self.transfer_mgr.progress_cb) + def _push_file(self, sftp, src, dst, callback): + sftp.put(src, dst, callback=callback) @classmethod def _path_exists(cls, sftp, path): @@ -442,7 +449,7 @@ class SshConnection(SshConnectionBase): else: return True - def _push_folder(self, sftp, src, dst): + def _push_folder(self, sftp, src, dst, callback): sftp.mkdir(dst) for entry in os.scandir(src): name = entry.name @@ -453,17 +460,17 @@ class SshConnection(SshConnectionBase): else: push = self._push_file - push(sftp, src_path, dst_path) + push(sftp, src_path, dst_path, callback) - def _push_path(self, sftp, src, dst): + def _push_path(self, sftp, src, dst, callback=None): logger.debug('Pushing via sftp: {} -> {}'.format(src, dst)) push = self._push_folder if os.path.isdir(src) else self._push_file - push(sftp, src, dst) + push(sftp, src, dst, callback) - def _pull_file(self, sftp, src, dst): - sftp.get(src, dst, callback=self.transfer_mgr.progress_cb) + def _pull_file(self, sftp, src, dst, callback): + sftp.get(src, dst, callback=callback) - def _pull_folder(self, sftp, src, dst): + def _pull_folder(self, sftp, src, dst, callback): os.makedirs(dst) for fileattr in sftp.listdir_attr(src): filename = fileattr.filename @@ -474,15 +481,15 @@ class SshConnection(SshConnectionBase): else: pull = self._pull_file - pull(sftp, src_path, dst_path) + pull(sftp, src_path, dst_path, callback) - def _pull_path(self, sftp, src, dst): + def _pull_path(self, sftp, src, dst, callback=None): logger.debug('Pulling via sftp: {} -> {}'.format(src, dst)) try: - self._pull_file(sftp, src, dst) + self._pull_file(sftp, src, dst, callback) except IOError: # Maybe that was a directory, so retry as such - self._pull_folder(sftp, src, dst) + self._pull_folder(sftp, src, dst, callback) def push(self, sources, dest, timeout=None): self._push_pull('push', sources, dest, timeout) @@ -494,8 +501,13 @@ class SshConnection(SshConnectionBase): if action not in ['push', 'pull']: raise ValueError("Action must be either `push` or `pull`") - # If timeout is set, or told not to poll - if timeout is not None or not self.poll_transfers: + def make_handle(obj): + handle = SSHTransferHandle(obj, manager=self.transfer_manager) + cm = self.transfer_manager.manage(sources, dest, action, handle) + return (handle, cm) + + # If timeout is set + if timeout is not None: if self.use_scp: scp = self._get_scp(timeout) scp_cmd = getattr(scp, 'put' if action == 'push' else 'get') @@ -509,20 +521,23 @@ class SshConnection(SshConnectionBase): for source in sources: sftp_cmd(sftp, source, dest) - # No timeout, and polling is set + # No timeout elif self.use_scp: - scp = self._get_scp(timeout) + scp = self._get_scp(timeout, callback=handle.progress_cb) + handle, cm = make_handle(scp) + scp_cmd = getattr(scp, 'put' if action == 'push' else 'get') - with _handle_paramiko_exceptions(), self.transfer_mgr.manage(sources, dest, action, scp): + with _handle_paramiko_exceptions(), cm: scp_msg = '{}ing via scp: {} -> {}'.format(action, sources, dest) logger.debug(scp_msg.capitalize()) scp_cmd(sources, dest, recursive=True) else: sftp = self._get_sftp(timeout) + handle, cm = make_handle(sftp) sftp_cmd = getattr(self, '_' + action + '_path') - with _handle_paramiko_exceptions(), self.transfer_mgr.manage(sources, dest, action, sftp): + with _handle_paramiko_exceptions(), cm: for source in sources: - sftp_cmd(sftp, source, dest) + sftp_cmd(sftp, source, dest, callback=handle.progress_cb) def execute(self, command, timeout=None, check_exit_code=True, as_root=False, strip_colors=True, will_succeed=False): #pylint: disable=unused-argument