From ddaa2f1621605642d3f54985bca92e72cf7668f9 Mon Sep 17 00:00:00 2001
From: Douglas Raillard <douglas.raillard@arm.com>
Date: Mon, 20 Feb 2023 18:18:11 +0000
Subject: [PATCH] connection: Rework TransferManager

* Split TransferManager and TransferHandle:
    * TransferManager deals with the generic monitoring. To abort a
      transfer, it simply cancels the transfer and raises an exception
      from manage().
    * TransferHandle provides a way for the manager to query the state
      of the transfer and cancel it. It is backend-specific.

* Remove most of the state in TransferManager, along with the associated
  background command leak etc

* Use a daemonic monitor thread to behave as excpected on interpreter
  shutdown.

* Ensure a transfer manager _always_ exists. When no management is
  desired, a noop object is used. This avoids using a None sentinel,
  which is invariably mishandled by some code leading to crashes.

* Try to merge more paths in the code to uncover as many issues as
  possible in testing.

* Fix percentage for SSHTransferHandle (transferred / (remaining +
  transferred) instead of transferred / remaining)

* Rename total_timeout TransferManager parameter and attribute to
  total_transfer_timeout to match the connection name parameter.
---
 devlib/connection.py    | 281 ++++++++++++++++++++--------------------
 devlib/target.py        |   2 +-
 devlib/utils/android.py |  61 +++++----
 devlib/utils/ssh.py     |  79 ++++++-----
 4 files changed, 227 insertions(+), 196 deletions(-)

diff --git a/devlib/connection.py b/devlib/connection.py
index ef0cb27..aacb48a 100644
--- a/devlib/connection.py
+++ b/devlib/connection.py
@@ -14,12 +14,11 @@
 #
 
 from abc import ABC, abstractmethod
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from datetime import datetime
 from functools import partial
 from weakref import WeakSet
 from shlex import quote
-from time import monotonic
 import os
 import signal
 import socket
@@ -61,13 +60,27 @@ class ConnectionBase(InitCheckpoint):
     """
     Base class for all connections.
     """
-    def __init__(self):
+    def __init__(
+        self,
+        poll_transfers=False,
+        start_transfer_poll_delay=30,
+        total_transfer_timeout=3600,
+        transfer_poll_period=30,
+    ):
         self._current_bg_cmds = set()
         self._closed = False
         self._close_lock = threading.Lock()
         self.busybox = None
         self.logger = logging.getLogger('Connection')
 
+        self.transfer_manager = TransferManager(
+            self,
+            start_transfer_poll_delay=start_transfer_poll_delay,
+            total_transfer_timeout=total_transfer_timeout,
+            transfer_poll_period=transfer_poll_period,
+        ) if poll_transfers else NoopTransferManager()
+
+
     def cancel_running_command(self):
         bg_cmds = set(self._current_bg_cmds)
         for bg_cmd in bg_cmds:
@@ -405,13 +418,13 @@ class ParamikoBackgroundCommand(BackgroundCommand):
                 b''.join(out[stderr])
             )
 
-        start = monotonic()
+        start = time.monotonic()
 
         while ret is None:
             # Even if ret is not None anymore, we need to drain the streams
             ret = self.poll()
 
-            if timeout is not None and ret is None and monotonic() - start >= timeout:
+            if timeout is not None and ret is None and time.monotonic() - start >= timeout:
                 self.cancel()
                 _stdout, _stderr = create_out()
                 raise subprocess.TimeoutExpired(self.cmd, timeout, _stdout, _stderr)
@@ -563,9 +576,89 @@ class AdbBackgroundCommand(BackgroundCommand):
         return self
 
 
-class TransferManagerBase(ABC):
+class TransferManager:
+    def __init__(self, conn, transfer_poll_period=30, start_transfer_poll_delay=30, total_transfer_timeout=3600):
+        self.conn = conn
+        self.transfer_poll_period = transfer_poll_period
+        self.total_transfer_timeout = total_transfer_timeout
+        self.start_transfer_poll_delay = start_transfer_poll_delay
 
-    def _pull_dest_size(self, dest):
+        self.logger = logging.getLogger('FileTransfer')
+
+    @contextmanager
+    def manage(self, sources, dest, direction, handle):
+        excep = None
+        stop_thread = threading.Event()
+
+        def monitor():
+            nonlocal excep
+
+            def cancel(reason):
+                self.logger.warning(
+                    f'Cancelling file transfer {sources} -> {dest} due to: {reason}'
+                )
+                handle.cancel()
+
+            start_t = time.monotonic()
+            stop_thread.wait(self.start_transfer_poll_delay)
+            while not stop_thread.wait(self.transfer_poll_period):
+                if not handle.isactive():
+                    cancel(reason='transfer inactive')
+                elif time.monotonic() - start_t > self.total_transfer_timeout:
+                    cancel(reason='transfer timed out')
+                    excep = TimeoutError(f'{direction}: {sources} -> {dest}')
+
+        m_thread = threading.Thread(target=monitor, daemon=True)
+        try:
+            m_thread.start()
+            yield self
+        finally:
+            stop_thread.set()
+            m_thread.join()
+            if excep is not None:
+                raise excep
+
+
+class NoopTransferManager:
+    def manage(self, *args, **kwargs):
+        return nullcontext(self)
+
+
+class TransferHandleBase(ABC):
+    def __init__(self, manager):
+        self.manager = manager
+
+    @property
+    def logger(self):
+        return self.manager.logger
+
+    @abstractmethod
+    def isactive(self):
+        pass
+
+    @abstractmethod
+    def cancel(self):
+        pass
+
+
+class PopenTransferHandle(TransferHandleBase):
+    def __init__(self, bg_cmd, dest, direction, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if direction == 'push':
+            sample_size = self._push_dest_size
+        elif direction == 'pull':
+            sample_size = self._pull_dest_size
+        else:
+            raise ValueError(f'Unknown direction: {direction}')
+
+        self.sample_size = lambda: sample_size(dest)
+
+        self.bg_cmd = bg_cmd
+        self.last_sample = 0
+
+    @staticmethod
+    def _pull_dest_size(dest):
         if os.path.isdir(dest):
             return sum(
                 os.stat(os.path.join(dirpath, f)).st_size
@@ -576,148 +669,58 @@ class TransferManagerBase(ABC):
             return os.stat(dest).st_size
 
     def _push_dest_size(self, dest):
-        cmd = '{} du -s {}'.format(quote(self.conn.busybox), quote(dest))
-        out = self.conn.execute(cmd)
-        try:
-            return int(out.split()[0])
-        except ValueError:
-            return 0
+        conn = self.manager.conn
+        cmd = '{} du -s -- {}'.format(quote(conn.busybox), quote(dest))
+        out = conn.execute(cmd)
+        return int(out.split()[0])
 
-    def __init__(self, conn, poll_period, start_transfer_poll_delay, total_timeout):
-        self.conn = conn
-        self.poll_period = poll_period
-        self.total_timeout = total_timeout
-        self.start_transfer_poll_delay = start_transfer_poll_delay
-
-        self.logger = logging.getLogger('FileTransfer')
-        self.managing = threading.Event()
-        self.transfer_started = threading.Event()
-        self.transfer_completed = threading.Event()
-        self.transfer_aborted = threading.Event()
-
-        self.monitor_thread = None
-        self.sources = None
-        self.dest = None
-        self.direction = None
-
-    @abstractmethod
-    def _cancel(self):
-        pass
-
-    def cancel(self, reason=None):
-        msg = 'Cancelling file transfer {} -> {}'.format(self.sources, self.dest)
-        if reason is not None:
-            msg += ' due to \'{}\''.format(reason)
-        self.logger.warning(msg)
-        self.transfer_aborted.set()
-        self._cancel()
-
-    @abstractmethod
-    def isactive(self):
-        pass
-
-    @contextmanager
-    def manage(self, sources, dest, direction):
-        try:
-            self.sources, self.dest, self.direction = sources, dest, direction
-            m_thread = threading.Thread(target=self._monitor)
-
-            self.transfer_completed.clear()
-            self.transfer_aborted.clear()
-            self.transfer_started.set()
-
-            m_thread.start()
-            yield self
-        except BaseException:
-            self.cancel(reason='exception during transfer')
-            raise
-        finally:
-            self.transfer_completed.set()
-            self.transfer_started.set()
-            m_thread.join()
-            self.transfer_started.clear()
-            self.transfer_completed.clear()
-            self.transfer_aborted.clear()
-
-    def _monitor(self):
-        start_t = monotonic()
-        self.transfer_completed.wait(self.start_transfer_poll_delay)
-        while not self.transfer_completed.wait(self.poll_period):
-            if not self.isactive():
-                self.cancel(reason='transfer inactive')
-            elif monotonic() - start_t > self.total_timeout:
-                self.cancel(reason='transfer timed out')
-
-
-class PopenTransferManager(TransferManagerBase):
-
-    def __init__(self, conn, poll_period=30, start_transfer_poll_delay=30, total_timeout=3600):
-        super().__init__(conn, poll_period, start_transfer_poll_delay, total_timeout)
-        self.transfer = None
-        self.last_sample = None
-
-    def _cancel(self):
-        if self.transfer:
-            self.transfer.cancel()
-            self.transfer = None
-            self.last_sample = None
+    def cancel(self):
+        self.bg_cmd.cancel()
 
     def isactive(self):
-        size_fn = self._push_dest_size if self.direction == 'push' else self._pull_dest_size
-        curr_size = size_fn(self.dest)
-        self.logger.debug('Polled file transfer, destination size {}'.format(curr_size))
-        active = True if self.last_sample is None else curr_size > self.last_sample
-        self.last_sample = curr_size
-        return active
-
-    def set_transfer_and_wait(self, popen_bg_cmd):
-        self.transfer = popen_bg_cmd
-        self.last_sample = None
-        ret = self.transfer.wait()
-
-        if ret and not self.transfer_aborted.is_set():
-            raise subprocess.CalledProcessError(ret, self.transfer.popen.args)
-        elif self.transfer_aborted.is_set():
-            raise TimeoutError(self.transfer.popen.args)
+        try:
+            curr_size = self.sample_size()
+        except Exception as e:
+            self.logger.debug(f'File size polling failed: {e}')
+            return True
+        else:
+            self.logger.debug(f'Polled file transfer, destination size: {curr_size}')
+            if curr_size:
+                active = curr_size > self.last_sample
+                self.last_sample = curr_size
+                return active
+            # If the file is empty it will never grow in size, so we assume
+            # everything is going well.
+            else:
+                return True
 
 
-class SSHTransferManager(TransferManagerBase):
+class SSHTransferHandle(TransferHandleBase):
+
+    def __init__(self, handle, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # SFTPClient or SSHClient
+        self.handle = handle
 
-    def __init__(self, conn, poll_period=30, start_transfer_poll_delay=30, total_timeout=3600):
-        super().__init__(conn, poll_period, start_transfer_poll_delay, total_timeout)
-        self.transferer = None
         self.progressed = False
-        self.transferred = None
-        self.to_transfer = None
+        self.transferred = 0
+        self.to_transfer = 0
 
-    def _cancel(self):
-        self.transferer.close()
+    def cancel(self):
+        self.handle.close()
 
     def isactive(self):
         progressed = self.progressed
-        self.progressed = False
-        msg = 'Polled transfer: {}% [{}B/{}B]'
-        pc = format((self.transferred / self.to_transfer) * 100, '.2f')
-        self.logger.debug(msg.format(pc, self.transferred, self.to_transfer))
+        if progressed:
+            self.progressed = False
+            pc = (self.transferred / self.to_transfer) * 100
+            self.logger.debug(
+                f'Polled transfer: {pc:.2f}% [{self.transferred}B/{self.to_transfer}B]'
+            )
         return progressed
 
-    @contextmanager
-    def manage(self, sources, dest, direction, transferer):
-        with super().manage(sources, dest, direction):
-            try:
-                self.progressed = False
-                self.transferer = transferer  # SFTPClient or SCPClient
-                yield self
-            except socket.error as e:
-                if self.transfer_aborted.is_set():
-                    self.transfer_aborted.clear()
-                    method = 'SCP' if self.conn.use_scp else 'SFTP'
-                    raise TimeoutError('{} {}: {} -> {}'.format(method, self.direction, sources, self.dest))
-                else:
-                    raise e
-
-    def progress_cb(self, to_transfer, transferred):
-        if self.transfer_started.is_set():
-            self.progressed = True
-            self.transferred = transferred
-            self.to_transfer = to_transfer
+    def progress_cb(self, transferred, to_transfer):
+        self.progressed = True
+        self.transferred = transferred
+        self.to_transfer = to_transfer
diff --git a/devlib/target.py b/devlib/target.py
index 4fb385f..4eeeb5c 100644
--- a/devlib/target.py
+++ b/devlib/target.py
@@ -2931,7 +2931,7 @@ class ChromeOsTarget(LinuxTarget):
         ssh_conn_params = ['host', 'username', 'password', 'keyfile',
                            'port', 'timeout', 'sudo_cmd',
                            'strict_host_check', 'use_scp',
-                           'total_timeout', 'poll_transfers',
+                           'total_transfer_timeout', 'poll_transfers',
                            'start_transfer_poll_delay']
         self.ssh_connection_settings = {}
         for setting in ssh_conn_params:
diff --git a/devlib/utils/android.py b/devlib/utils/android.py
index 1cae582..cc6c7e4 100755
--- a/devlib/utils/android.py
+++ b/devlib/utils/android.py
@@ -39,7 +39,7 @@ from shlex import quote
 
 from devlib.exception import TargetTransientError, TargetStableError, HostError, TargetTransientCalledProcessError, TargetStableCalledProcessError, AdbRootError
 from devlib.utils.misc import check_output, which, ABI_MAP, redirect_streams, get_subprocess
-from devlib.connection import ConnectionBase, AdbBackgroundCommand, PopenBackgroundCommand, PopenTransferManager
+from devlib.connection import ConnectionBase, AdbBackgroundCommand, PopenBackgroundCommand, PopenTransferHandle
 
 
 logger = logging.getLogger('android')
@@ -278,26 +278,32 @@ class AdbConnection(ConnectionBase):
         self._connected_as_root[self.device] = state
 
     # pylint: disable=unused-argument
-    def __init__(self, device=None, timeout=None, platform=None, adb_server=None,
-                 adb_as_root=False, connection_attempts=MAX_ATTEMPTS,
-                 poll_transfers=False,
-                 start_transfer_poll_delay=30,
-                 total_transfer_timeout=3600,
-                 transfer_poll_period=30,):
-        super().__init__()
+    def __init__(
+        self,
+        device=None,
+        timeout=None,
+        platform=None,
+        adb_server=None,
+        adb_as_root=False,
+        connection_attempts=MAX_ATTEMPTS,
+
+        poll_transfers=False,
+        start_transfer_poll_delay=30,
+        total_transfer_timeout=3600,
+        transfer_poll_period=30,
+    ):
+        super().__init__(
+            poll_transfers=poll_transfers,
+            start_transfer_poll_delay=start_transfer_poll_delay,
+            total_transfer_timeout=total_transfer_timeout,
+            transfer_poll_period=transfer_poll_period,
+        )
         self.timeout = timeout if timeout is not None else self.default_timeout
         if device is None:
             device = adb_get_device(timeout=timeout, adb_server=adb_server)
         self.device = device
         self.adb_server = adb_server
         self.adb_as_root = adb_as_root
-        self.poll_transfers = poll_transfers
-        if poll_transfers:
-            transfer_opts = {'start_transfer_poll_delay': start_transfer_poll_delay,
-                            'total_timeout': total_transfer_timeout,
-                            'poll_period': transfer_poll_period,
-                            }
-        self.transfer_mgr = PopenTransferManager(self, **transfer_opts) if poll_transfers else None
         lock, nr_active = AdbConnection.active_connections
         with lock:
             nr_active[self.device] += 1
@@ -330,17 +336,24 @@ class AdbConnection(ConnectionBase):
         paths = ' '.join(map(do_quote, paths))
 
         command = "{} {}".format(action, paths)
-        if timeout or not self.poll_transfers:
+        if timeout:
             adb_command(self.device, command, timeout=timeout, adb_server=self.adb_server)
         else:
-            with self.transfer_mgr.manage(sources, dest, action):
-                bg_cmd = adb_command_background(
-                    device=self.device,
-                    conn=self,
-                    command=command,
-                    adb_server=self.adb_server
-                )
-                self.transfer_mgr.set_transfer_and_wait(bg_cmd)
+            bg_cmd = adb_command_background(
+                device=self.device,
+                conn=self,
+                command=command,
+                adb_server=self.adb_server
+            )
+
+            handle = PopenTransferHandle(
+                manager=self.transfer_manager,
+                bg_cmd=bg_cmd,
+                dest=dest,
+                direction=action
+            )
+            with bg_cmd, self.transfer_manager.manage(sources, dest, action, handle):
+                bg_cmd.communicate()
 
     # pylint: disable=unused-argument
     def execute(self, command, timeout=None, check_exit_code=False,
diff --git a/devlib/utils/ssh.py b/devlib/utils/ssh.py
index 07a365f..1f517a3 100644
--- a/devlib/utils/ssh.py
+++ b/devlib/utils/ssh.py
@@ -59,7 +59,7 @@ from devlib.utils.misc import (which, strip_bash_colors, check_output,
                                sanitize_cmd_template, memoized, redirect_streams)
 from devlib.utils.types import boolean
 from devlib.connection import (ConnectionBase, ParamikoBackgroundCommand, PopenBackgroundCommand,
-                               SSHTransferManager)
+                               SSHTransferHandle)
 
 
 DEFAULT_SSH_SUDO_COMMAND = "sudo -k -p ' ' -S -- sh -c {}"
@@ -295,8 +295,18 @@ class SshConnectionBase(ConnectionBase):
                  platform=None,
                  sudo_cmd=DEFAULT_SSH_SUDO_COMMAND,
                  strict_host_check=True,
+
+                 poll_transfers=False,
+                 start_transfer_poll_delay=30,
+                 total_transfer_timeout=3600,
+                 transfer_poll_period=30,
                  ):
-        super().__init__()
+        super().__init__(
+            poll_transfers=poll_transfers,
+            start_transfer_poll_delay=start_transfer_poll_delay,
+            total_transfer_timeout=total_transfer_timeout,
+            transfer_poll_period=transfer_poll_period,
+        )
         self._connected_as_root = None
         self.host = host
         self.username = username
@@ -337,24 +347,21 @@ class SshConnection(SshConnectionBase):
             platform=platform,
             sudo_cmd=sudo_cmd,
             strict_host_check=strict_host_check,
+
+            poll_transfers=poll_transfers,
+            start_transfer_poll_delay=start_transfer_poll_delay,
+            total_transfer_timeout=total_transfer_timeout,
+            transfer_poll_period=transfer_poll_period,
         )
         self.timeout = timeout if timeout is not None else self.default_timeout
 
         # Allow using scp for file transfer if sftp is not supported
         self.use_scp = use_scp
-        self.poll_transfers=poll_transfers
-        if poll_transfers:
-            transfer_opts = {'start_transfer_poll_delay': start_transfer_poll_delay,
-                            'total_timeout': total_transfer_timeout,
-                            'poll_period': transfer_poll_period,
-                            }
-
         if self.use_scp:
             logger.debug('Using SCP for file transfer')
         else:
             logger.debug('Using SFTP for file transfer')
 
-        self.transfer_mgr = SSHTransferManager(self, **transfer_opts) if poll_transfers else None
         self.client = self._make_client()
         atexit.register(self.close)
 
@@ -426,12 +433,12 @@ class SshConnection(SshConnectionBase):
         return sftp
 
     @functools.lru_cache()
-    def _get_scp(self, timeout):
-        cb = lambda _, to_transfer, transferred: self.transfer_mgr.progress_cb(to_transfer, transferred)
+    def _get_scp(self, timeout, callback=lambda *_: None):
+        cb = lambda _, to_transfer, transferred: callback(to_transfer, transferred)
         return SCPClient(self.client.get_transport(), socket_timeout=timeout, progress=cb)
 
-    def _push_file(self, sftp, src, dst):
-        sftp.put(src, dst, callback=self.transfer_mgr.progress_cb)
+    def _push_file(self, sftp, src, dst, callback):
+        sftp.put(src, dst, callback=callback)
 
     @classmethod
     def _path_exists(cls, sftp, path):
@@ -442,7 +449,7 @@ class SshConnection(SshConnectionBase):
         else:
             return True
 
-    def _push_folder(self, sftp, src, dst):
+    def _push_folder(self, sftp, src, dst, callback):
         sftp.mkdir(dst)
         for entry in os.scandir(src):
             name = entry.name
@@ -453,17 +460,17 @@ class SshConnection(SshConnectionBase):
             else:
                 push = self._push_file
 
-            push(sftp, src_path, dst_path)
+            push(sftp, src_path, dst_path, callback)
 
-    def _push_path(self, sftp, src, dst):
+    def _push_path(self, sftp, src, dst, callback=None):
         logger.debug('Pushing via sftp: {} -> {}'.format(src, dst))
         push = self._push_folder if os.path.isdir(src) else self._push_file
-        push(sftp, src, dst)
+        push(sftp, src, dst, callback)
 
-    def _pull_file(self, sftp, src, dst):
-        sftp.get(src, dst, callback=self.transfer_mgr.progress_cb)
+    def _pull_file(self, sftp, src, dst, callback):
+        sftp.get(src, dst, callback=callback)
 
-    def _pull_folder(self, sftp, src, dst):
+    def _pull_folder(self, sftp, src, dst, callback):
         os.makedirs(dst)
         for fileattr in sftp.listdir_attr(src):
             filename = fileattr.filename
@@ -474,15 +481,15 @@ class SshConnection(SshConnectionBase):
             else:
                 pull = self._pull_file
 
-            pull(sftp, src_path, dst_path)
+            pull(sftp, src_path, dst_path, callback)
 
-    def _pull_path(self, sftp, src, dst):
+    def _pull_path(self, sftp, src, dst, callback=None):
         logger.debug('Pulling via sftp: {} -> {}'.format(src, dst))
         try:
-            self._pull_file(sftp, src, dst)
+            self._pull_file(sftp, src, dst, callback)
         except IOError:
             # Maybe that was a directory, so retry as such
-            self._pull_folder(sftp, src, dst)
+            self._pull_folder(sftp, src, dst, callback)
 
     def push(self, sources, dest, timeout=None):
         self._push_pull('push', sources, dest, timeout)
@@ -494,8 +501,13 @@ class SshConnection(SshConnectionBase):
         if action not in ['push', 'pull']:
             raise ValueError("Action must be either `push` or `pull`")
 
-        # If timeout is set, or told not to poll
-        if timeout is not None or not self.poll_transfers:
+        def make_handle(obj):
+            handle = SSHTransferHandle(obj, manager=self.transfer_manager)
+            cm = self.transfer_manager.manage(sources, dest, action, handle)
+            return (handle, cm)
+
+        # If timeout is set
+        if timeout is not None:
             if self.use_scp:
                 scp = self._get_scp(timeout)
                 scp_cmd = getattr(scp, 'put' if action == 'push' else 'get')
@@ -509,20 +521,23 @@ class SshConnection(SshConnectionBase):
                     for source in sources:
                         sftp_cmd(sftp, source, dest)
 
-        # No timeout, and polling is set
+        # No timeout
         elif self.use_scp:
-            scp = self._get_scp(timeout)
+            scp = self._get_scp(timeout, callback=handle.progress_cb)
+            handle, cm = make_handle(scp)
+
             scp_cmd = getattr(scp, 'put' if action == 'push' else 'get')
-            with _handle_paramiko_exceptions(), self.transfer_mgr.manage(sources, dest, action, scp):
+            with _handle_paramiko_exceptions(), cm:
                 scp_msg = '{}ing via scp: {} -> {}'.format(action, sources, dest)
                 logger.debug(scp_msg.capitalize())
                 scp_cmd(sources, dest, recursive=True)
         else:
             sftp = self._get_sftp(timeout)
+            handle, cm = make_handle(sftp)
             sftp_cmd = getattr(self, '_' + action + '_path')
-            with _handle_paramiko_exceptions(), self.transfer_mgr.manage(sources, dest, action, sftp):
+            with _handle_paramiko_exceptions(), cm:
                 for source in sources:
-                    sftp_cmd(sftp, source, dest)
+                    sftp_cmd(sftp, source, dest, callback=handle.progress_cb)
 
     def execute(self, command, timeout=None, check_exit_code=True,
                 as_root=False, strip_colors=True, will_succeed=False): #pylint: disable=unused-argument